Optimize ‘hist’ for multi-core CPU (#4529)
* Initial performance optimizations for xgboost * remove includes * revert float->double * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * Check existence of _mm_prefetch and __builtin_prefetch * Fix lint * optimizations for CPU * appling comments in review * add some comments, code refactoring * fixing issues in CI * adding runtime checks * remove 1 extra check * remove extra checks in BuildHist * remove checks * add debug info * added debug info * revert changes * added comments * Apply suggestions from code review Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu> * apply review comments * Remove unused function CreateNewNodes() * Add descriptive comment on node_idx variable in QuantileHistMaker::Builder::BuildHistsBatch()
This commit is contained in:
parent
abffbe014e
commit
4d6590be3c
@ -8,11 +8,11 @@
|
|||||||
#ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
|
#ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_
|
||||||
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
|
#define XGBOOST_COMMON_COLUMN_MATRIX_H_
|
||||||
|
|
||||||
|
#include <dmlc/timer.h>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "hist_util.h"
|
#include "hist_util.h"
|
||||||
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
|
|
||||||
@ -51,6 +51,10 @@ class Column {
|
|||||||
}
|
}
|
||||||
const size_t* GetRowData() const { return row_ind_; }
|
const size_t* GetRowData() const { return row_ind_; }
|
||||||
|
|
||||||
|
const uint32_t* GetIndex() const {
|
||||||
|
return index_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ColumnType type_;
|
ColumnType type_;
|
||||||
const uint32_t* index_;
|
const uint32_t* index_;
|
||||||
@ -80,7 +84,7 @@ class ColumnMatrix {
|
|||||||
std::fill(feature_counts_.begin(), feature_counts_.end(), 0);
|
std::fill(feature_counts_.begin(), feature_counts_.end(), 0);
|
||||||
|
|
||||||
uint32_t max_val = std::numeric_limits<uint32_t>::max();
|
uint32_t max_val = std::numeric_limits<uint32_t>::max();
|
||||||
for (bst_uint fid = 0; fid < nfeature; ++fid) {
|
for (int32_t fid = 0; fid < nfeature; ++fid) {
|
||||||
CHECK_LE(gmat.cut.row_ptr[fid + 1] - gmat.cut.row_ptr[fid], max_val);
|
CHECK_LE(gmat.cut.row_ptr[fid + 1] - gmat.cut.row_ptr[fid], max_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,13 +117,12 @@ class ColumnMatrix {
|
|||||||
boundary_[fid].index_end = accum_index_;
|
boundary_[fid].index_end = accum_index_;
|
||||||
boundary_[fid].row_ind_end = accum_row_ind_;
|
boundary_[fid].row_ind_end = accum_row_ind_;
|
||||||
}
|
}
|
||||||
|
|
||||||
index_.resize(boundary_[nfeature - 1].index_end);
|
index_.resize(boundary_[nfeature - 1].index_end);
|
||||||
row_ind_.resize(boundary_[nfeature - 1].row_ind_end);
|
row_ind_.resize(boundary_[nfeature - 1].row_ind_end);
|
||||||
|
|
||||||
// store least bin id for each feature
|
// store least bin id for each feature
|
||||||
index_base_.resize(nfeature);
|
index_base_.resize(nfeature);
|
||||||
for (bst_uint fid = 0; fid < nfeature; ++fid) {
|
for (int32_t fid = 0; fid < nfeature; ++fid) {
|
||||||
index_base_[fid] = gmat.cut.row_ptr[fid];
|
index_base_[fid] = gmat.cut.row_ptr[fid];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +1,15 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2017-2019 by Contributors
|
* Copyright 2017-2019 by Contributors
|
||||||
* \file hist_util.h
|
* \file hist_util.cc
|
||||||
*/
|
*/
|
||||||
|
#include "./hist_util.h"
|
||||||
|
#include <dmlc/timer.h>
|
||||||
#include <rabit/rabit.h>
|
#include <rabit/rabit.h>
|
||||||
#include <dmlc/omp.h>
|
#include <dmlc/omp.h>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "./random.h"
|
#include "./random.h"
|
||||||
#include "./column_matrix.h"
|
#include "./column_matrix.h"
|
||||||
#include "./hist_util.h"
|
|
||||||
#include "./quantile.h"
|
#include "./quantile.h"
|
||||||
#include "./../tree/updater_quantile_hist.h"
|
#include "./../tree/updater_quantile_hist.h"
|
||||||
|
|
||||||
@ -178,7 +178,7 @@ uint32_t HistCutMatrix::GetBinIdx(const Entry& e) {
|
|||||||
|
|
||||||
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
|
void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
|
||||||
cut.Init(p_fmat, max_num_bins);
|
cut.Init(p_fmat, max_num_bins);
|
||||||
const size_t nthread = omp_get_max_threads();
|
const int32_t nthread = omp_get_max_threads();
|
||||||
const uint32_t nbins = cut.row_ptr.back();
|
const uint32_t nbins = cut.row_ptr.back();
|
||||||
hit_count.resize(nbins, 0);
|
hit_count.resize(nbins, 0);
|
||||||
hit_count_tloc_.resize(nthread * nbins, 0);
|
hit_count_tloc_.resize(nthread * nbins, 0);
|
||||||
@ -260,8 +260,8 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#pragma omp parallel for num_threads(nthread) schedule(static)
|
#pragma omp parallel for num_threads(nthread) schedule(static)
|
||||||
for (bst_omp_uint idx = 0; idx < bst_omp_uint(nbins); ++idx) {
|
for (int32_t idx = 0; idx < int32_t(nbins); ++idx) {
|
||||||
for (size_t tid = 0; tid < nthread; ++tid) {
|
for (int32_t tid = 0; tid < nthread; ++tid) {
|
||||||
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
|
hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -411,7 +411,7 @@ FastFeatureGrouping(const GHistIndexMatrix& gmat,
|
|||||||
for (auto fid : group) {
|
for (auto fid : group) {
|
||||||
nnz += feature_nnz[fid];
|
nnz += feature_nnz[fid];
|
||||||
}
|
}
|
||||||
double nnz_rate = static_cast<double>(nnz) / nrow;
|
float nnz_rate = static_cast<float>(nnz) / nrow;
|
||||||
// take apart small sparse group, due it will not gain on speed
|
// take apart small sparse group, due it will not gain on speed
|
||||||
if (nnz_rate <= param.sparse_threshold) {
|
if (nnz_rate <= param.sparse_threshold) {
|
||||||
for (auto fid : group) {
|
for (auto fid : group) {
|
||||||
@ -496,177 +496,145 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
|
// used when data layout is kDenseDataZeroBased or kDenseDataOneBased
|
||||||
const RowSetCollection::Elem row_indices,
|
// it means that "row_ptr" is not needed for hist computations
|
||||||
const GHistIndexMatrix& gmat,
|
void BuildHistLocalDense(size_t istart, size_t iend, size_t nrows, const size_t* rid,
|
||||||
GHistRow hist) {
|
const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
|
||||||
const size_t nthread = static_cast<size_t>(this->nthread_);
|
GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat_global) {
|
||||||
data_.resize(nbins_ * nthread_);
|
GradStatHist grad_stat; // make local var to prevent false sharing
|
||||||
|
|
||||||
const size_t* rid = row_indices.begin;
|
|
||||||
const size_t nrows = row_indices.Size();
|
|
||||||
const uint32_t* index = gmat.index.data();
|
|
||||||
const size_t* row_ptr = gmat.row_ptr.data();
|
|
||||||
const float* pgh = reinterpret_cast<const float*>(gpair.data());
|
|
||||||
|
|
||||||
double* hist_data = reinterpret_cast<double*>(hist.data());
|
|
||||||
double* data = reinterpret_cast<double*>(data_.data());
|
|
||||||
|
|
||||||
const size_t block_size = 512;
|
|
||||||
size_t n_blocks = nrows/block_size;
|
|
||||||
n_blocks += !!(nrows - n_blocks*block_size);
|
|
||||||
|
|
||||||
const size_t nthread_to_process = std::min(nthread, n_blocks);
|
|
||||||
memset(thread_init_.data(), '\0', nthread_to_process*sizeof(size_t));
|
|
||||||
|
|
||||||
|
const size_t n_features = row_ptr[rid[istart]+1] - row_ptr[rid[istart]];
|
||||||
const size_t cache_line_size = 64;
|
const size_t cache_line_size = 64;
|
||||||
|
const size_t prefetch_step = cache_line_size / sizeof(*index);
|
||||||
const size_t prefetch_offset = 10;
|
const size_t prefetch_offset = 10;
|
||||||
|
|
||||||
size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
|
size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
|
||||||
no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
|
no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
|
||||||
|
|
||||||
#pragma omp parallel for num_threads(nthread_to_process) schedule(guided)
|
// if read each row in some block of bin-matrix - it's dense block
|
||||||
for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
|
// and we dont need SW prefetch in this case
|
||||||
dmlc::omp_uint tid = omp_get_thread_num();
|
const bool denseBlock = (rid[iend-1] - rid[istart]) == (iend - istart - 1);
|
||||||
double* data_local_hist = ((nthread_to_process == 1) ? hist_data :
|
|
||||||
reinterpret_cast<double*>(data_.data() + tid * nbins_));
|
|
||||||
|
|
||||||
if (!thread_init_[tid]) {
|
if (iend < nrows - no_prefetch_size && !denseBlock) {
|
||||||
memset(data_local_hist, '\0', 2*nbins_*sizeof(double));
|
|
||||||
thread_init_[tid] = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t istart = iblock*block_size;
|
|
||||||
const size_t iend = (((iblock+1)*block_size > nrows) ? nrows : istart + block_size);
|
|
||||||
for (size_t i = istart; i < iend; ++i) {
|
for (size_t i = istart; i < iend; ++i) {
|
||||||
const size_t icol_start = row_ptr[rid[i]];
|
const size_t icol_start = rid[i] * n_features;
|
||||||
const size_t icol_end = row_ptr[rid[i]+1];
|
const size_t icol_start_prefetch = rid[i+prefetch_offset] * n_features;
|
||||||
|
|
||||||
if (i < nrows - no_prefetch_size) {
|
|
||||||
PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
|
|
||||||
PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t j = icol_start; j < icol_end; ++j) {
|
|
||||||
const uint32_t idx_bin = 2*index[j];
|
|
||||||
const size_t idx_gh = 2*rid[i];
|
const size_t idx_gh = 2*rid[i];
|
||||||
|
|
||||||
|
PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
|
||||||
|
|
||||||
|
for (size_t j = icol_start_prefetch; j < icol_start_prefetch + n_features;
|
||||||
|
j += prefetch_step) {
|
||||||
|
PREFETCH_READ_T0(index + j);
|
||||||
|
}
|
||||||
|
|
||||||
|
grad_stat.sum_grad += pgh[idx_gh];
|
||||||
|
grad_stat.sum_hess += pgh[idx_gh+1];
|
||||||
|
|
||||||
|
for (size_t j = icol_start; j < icol_start + n_features; ++j) {
|
||||||
|
const uint32_t idx_bin = 2*index[j];
|
||||||
|
data_local_hist[idx_bin] += pgh[idx_gh];
|
||||||
|
data_local_hist[idx_bin+1] += pgh[idx_gh+1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for (size_t i = istart; i < iend; ++i) {
|
||||||
|
const size_t icol_start = rid[i] * n_features;
|
||||||
|
const size_t idx_gh = 2*rid[i];
|
||||||
|
grad_stat.sum_grad += pgh[idx_gh];
|
||||||
|
grad_stat.sum_hess += pgh[idx_gh+1];
|
||||||
|
|
||||||
|
for (size_t j = icol_start; j < icol_start + n_features; ++j) {
|
||||||
|
const uint32_t idx_bin = 2*index[j];
|
||||||
data_local_hist[idx_bin] += pgh[idx_gh];
|
data_local_hist[idx_bin] += pgh[idx_gh];
|
||||||
data_local_hist[idx_bin+1] += pgh[idx_gh+1];
|
data_local_hist[idx_bin+1] += pgh[idx_gh+1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
grad_stat_global->Add(grad_stat);
|
||||||
if (nthread_to_process > 1) {
|
|
||||||
const size_t size = (2*nbins_);
|
|
||||||
const size_t block_size = 1024;
|
|
||||||
size_t n_blocks = size/block_size;
|
|
||||||
n_blocks += !!(size - n_blocks*block_size);
|
|
||||||
|
|
||||||
size_t n_worked_bins = 0;
|
|
||||||
for (size_t i = 0; i < nthread_to_process; ++i) {
|
|
||||||
if (thread_init_[i]) {
|
|
||||||
thread_init_[n_worked_bins++] = i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
|
|
||||||
for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
|
|
||||||
const size_t istart = iblock * block_size;
|
|
||||||
const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
|
|
||||||
|
|
||||||
const size_t bin = 2 * thread_init_[0] * nbins_;
|
|
||||||
memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));
|
|
||||||
|
|
||||||
for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
|
|
||||||
const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
|
|
||||||
for (size_t i = istart; i < iend; i++) {
|
|
||||||
hist_data[i] += data[bin + i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void GHistBuilder::BuildBlockHist(const std::vector<GradientPair>& gpair,
|
// used when data layout is kSparseData
|
||||||
const RowSetCollection::Elem row_indices,
|
// it means that "row_ptr" is needed for hist computations
|
||||||
const GHistIndexBlockMatrix& gmatb,
|
void BuildHistLocalSparse(size_t istart, size_t iend, size_t nrows, const size_t* rid,
|
||||||
GHistRow hist) {
|
const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
|
||||||
constexpr int kUnroll = 8; // loop unrolling factor
|
GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat_global) {
|
||||||
const size_t nblock = gmatb.GetNumBlock();
|
GradStatHist grad_stat; // make local var to prevent false sharing
|
||||||
const size_t nrows = row_indices.end - row_indices.begin;
|
|
||||||
const size_t rest = nrows % kUnroll;
|
|
||||||
|
|
||||||
#if defined(_OPENMP)
|
const size_t cache_line_size = 64;
|
||||||
const auto nthread = static_cast<bst_omp_uint>(this->nthread_); // NOLINT
|
const size_t prefetch_step = cache_line_size / sizeof(index[0]);
|
||||||
#endif // defined(_OPENMP)
|
const size_t prefetch_offset = 10;
|
||||||
tree::GradStats* p_hist = hist.data();
|
|
||||||
|
|
||||||
#pragma omp parallel for num_threads(nthread) schedule(guided)
|
size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
|
||||||
for (bst_omp_uint bid = 0; bid < nblock; ++bid) {
|
no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
|
||||||
auto gmat = gmatb[bid];
|
|
||||||
|
|
||||||
for (size_t i = 0; i < nrows - rest; i += kUnroll) {
|
// if read each row in some block of bin-matrix - it's dense block
|
||||||
size_t rid[kUnroll];
|
// and we dont need SW prefetch in this case
|
||||||
size_t ibegin[kUnroll];
|
const bool denseBlock = (rid[iend-1] - rid[istart]) == (iend - istart);
|
||||||
size_t iend[kUnroll];
|
|
||||||
GradientPair stat[kUnroll];
|
|
||||||
|
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
if (iend < nrows - no_prefetch_size && !denseBlock) {
|
||||||
rid[k] = row_indices.begin[i + k];
|
for (size_t i = istart; i < iend; ++i) {
|
||||||
ibegin[k] = gmat.row_ptr[rid[k]];
|
const size_t icol_start = row_ptr[rid[i]];
|
||||||
iend[k] = gmat.row_ptr[rid[k] + 1];
|
const size_t icol_end = row_ptr[rid[i]+1];
|
||||||
stat[k] = gpair[rid[k]];
|
const size_t idx_gh = 2*rid[i];
|
||||||
|
|
||||||
|
const size_t icol_start10 = row_ptr[rid[i+prefetch_offset]];
|
||||||
|
const size_t icol_end10 = row_ptr[rid[i+prefetch_offset]+1];
|
||||||
|
|
||||||
|
PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
|
||||||
|
|
||||||
|
for (size_t j = icol_start10; j < icol_end10; j+=prefetch_step) {
|
||||||
|
PREFETCH_READ_T0(index + j);
|
||||||
}
|
}
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
|
||||||
for (size_t j = ibegin[k]; j < iend[k]; ++j) {
|
grad_stat.sum_grad += pgh[idx_gh];
|
||||||
const uint32_t bin = gmat.index[j];
|
grad_stat.sum_hess += pgh[idx_gh+1];
|
||||||
p_hist[bin].Add(stat[k]);
|
|
||||||
}
|
for (size_t j = icol_start; j < icol_end; ++j) {
|
||||||
}
|
const uint32_t idx_bin = 2*index[j];
|
||||||
}
|
data_local_hist[idx_bin] += pgh[idx_gh];
|
||||||
for (size_t i = nrows - rest; i < nrows; ++i) {
|
data_local_hist[idx_bin+1] += pgh[idx_gh+1];
|
||||||
const size_t rid = row_indices.begin[i];
|
}
|
||||||
const size_t ibegin = gmat.row_ptr[rid];
|
}
|
||||||
const size_t iend = gmat.row_ptr[rid + 1];
|
} else {
|
||||||
const GradientPair stat = gpair[rid];
|
for (size_t i = istart; i < iend; ++i) {
|
||||||
for (size_t j = ibegin; j < iend; ++j) {
|
const size_t icol_start = row_ptr[rid[i]];
|
||||||
const uint32_t bin = gmat.index[j];
|
const size_t icol_end = row_ptr[rid[i]+1];
|
||||||
p_hist[bin].Add(stat);
|
const size_t idx_gh = 2*rid[i];
|
||||||
|
|
||||||
|
grad_stat.sum_grad += pgh[idx_gh];
|
||||||
|
grad_stat.sum_hess += pgh[idx_gh+1];
|
||||||
|
|
||||||
|
for (size_t j = icol_start; j < icol_end; ++j) {
|
||||||
|
const uint32_t idx_bin = 2*index[j];
|
||||||
|
data_local_hist[idx_bin] += pgh[idx_gh];
|
||||||
|
data_local_hist[idx_bin+1] += pgh[idx_gh+1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
grad_stat_global->Add(grad_stat);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GHistBuilder::SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
|
void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
|
||||||
const uint32_t nbins = static_cast<bst_omp_uint>(nbins_);
|
GradStatHist* p_self = self.data();
|
||||||
constexpr int kUnroll = 8; // loop unrolling factor
|
GradStatHist* p_sibling = sibling.data();
|
||||||
const uint32_t rest = nbins % kUnroll;
|
GradStatHist* p_parent = parent.data();
|
||||||
|
|
||||||
#if defined(_OPENMP)
|
const size_t size = self.size();
|
||||||
const auto nthread = static_cast<bst_omp_uint>(this->nthread_); // NOLINT
|
CHECK_EQ(sibling.size(), size);
|
||||||
#endif // defined(_OPENMP)
|
CHECK_EQ(parent.size(), size);
|
||||||
tree::GradStats* p_self = self.data();
|
|
||||||
tree::GradStats* p_sibling = sibling.data();
|
|
||||||
tree::GradStats* p_parent = parent.data();
|
|
||||||
|
|
||||||
#pragma omp parallel for num_threads(nthread) schedule(static)
|
const size_t block_size = 1024; // aproximatly 1024 values per block
|
||||||
for (bst_omp_uint bin_id = 0;
|
size_t n_blocks = size/block_size + !!(size%block_size);
|
||||||
bin_id < static_cast<bst_omp_uint>(nbins - rest); bin_id += kUnroll) {
|
|
||||||
tree::GradStats pb[kUnroll];
|
#pragma omp parallel for
|
||||||
tree::GradStats sb[kUnroll];
|
for (int iblock = 0; iblock < n_blocks; ++iblock) {
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
const size_t ibegin = iblock*block_size;
|
||||||
pb[k] = p_parent[bin_id + k];
|
const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size);
|
||||||
}
|
for (bst_omp_uint bin_id = ibegin; bin_id < iend; bin_id++) {
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
|
||||||
sb[k] = p_sibling[bin_id + k];
|
|
||||||
}
|
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
|
||||||
p_self[bin_id + k].SetSubstract(pb[k], sb[k]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (uint32_t bin_id = nbins - rest; bin_id < nbins; ++bin_id) {
|
|
||||||
p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
|
p_self[bin_id].SetSubstract(p_parent[bin_id], p_sibling[bin_id]);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
|
|||||||
@ -11,13 +11,50 @@
|
|||||||
#include <xgboost/generic_parameters.h>
|
#include <xgboost/generic_parameters.h>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
#include "row_set.h"
|
#include "row_set.h"
|
||||||
#include "../tree/param.h"
|
#include "../tree/param.h"
|
||||||
#include "./quantile.h"
|
#include "./quantile.h"
|
||||||
#include "./timer.h"
|
#include "./timer.h"
|
||||||
#include "../include/rabit/rabit.h"
|
#include "../include/rabit/rabit.h"
|
||||||
|
#include "random.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated.
|
||||||
|
*/
|
||||||
|
template<typename T, size_t MaxStackSize>
|
||||||
|
class MemStackAllocator {
|
||||||
|
public:
|
||||||
|
explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
|
||||||
|
}
|
||||||
|
|
||||||
|
T* Get() {
|
||||||
|
if (!ptr_) {
|
||||||
|
if (MaxStackSize >= required_size_) {
|
||||||
|
ptr_ = stack_mem_;
|
||||||
|
} else {
|
||||||
|
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
|
||||||
|
do_free_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
~MemStackAllocator() {
|
||||||
|
if (do_free_) free(ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
T* ptr_ = nullptr;
|
||||||
|
bool do_free_ = false;
|
||||||
|
size_t required_size_;
|
||||||
|
T stack_mem_[MaxStackSize];
|
||||||
|
};
|
||||||
|
|
||||||
namespace common {
|
namespace common {
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -134,9 +171,10 @@ using GHistIndexRow = Span<uint32_t const>;
|
|||||||
*/
|
*/
|
||||||
struct GHistIndexMatrix {
|
struct GHistIndexMatrix {
|
||||||
/*! \brief row pointer to rows by element position */
|
/*! \brief row pointer to rows by element position */
|
||||||
std::vector<size_t> row_ptr;
|
// std::vector<size_t> row_ptr;
|
||||||
|
SimpleArray<size_t> row_ptr;
|
||||||
/*! \brief The index data */
|
/*! \brief The index data */
|
||||||
std::vector<uint32_t> index;
|
SimpleArray<uint32_t> index;
|
||||||
/*! \brief hit count of each index */
|
/*! \brief hit count of each index */
|
||||||
std::vector<size_t> hit_count;
|
std::vector<size_t> hit_count;
|
||||||
/*! \brief The corresponding cuts */
|
/*! \brief The corresponding cuts */
|
||||||
@ -170,6 +208,11 @@ struct GHistIndexBlock {
|
|||||||
|
|
||||||
inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index)
|
inline GHistIndexBlock(const size_t* row_ptr, const uint32_t* index)
|
||||||
: row_ptr(row_ptr), index(index) {}
|
: row_ptr(row_ptr), index(index) {}
|
||||||
|
|
||||||
|
// get i-th row
|
||||||
|
inline GHistIndexRow operator[](size_t i) const {
|
||||||
|
return {&index[0] + row_ptr[i], detail::ptrdiff_t(row_ptr[i + 1] - row_ptr[i])};
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class ColumnMatrix;
|
class ColumnMatrix;
|
||||||
@ -202,12 +245,63 @@ class GHistIndexBlockMatrix {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief histogram of graident statistics for a single node.
|
* \brief used instead of GradStats to have float instead of double to reduce histograms
|
||||||
* Consists of multiple GradStats, each entry showing total graident statistics
|
* this improves performance by 10-30% and memory consumption for histograms by 2x
|
||||||
* for that particular bin
|
* accuracy in both cases is the same
|
||||||
* Uses global bin id so as to represent all features simultaneously
|
|
||||||
*/
|
*/
|
||||||
using GHistRow = Span<tree::GradStats>;
|
struct GradStatHist {
|
||||||
|
typedef float GradType;
|
||||||
|
/*! \brief sum gradient statistics */
|
||||||
|
GradType sum_grad;
|
||||||
|
/*! \brief sum hessian statistics */
|
||||||
|
GradType sum_hess;
|
||||||
|
|
||||||
|
GradStatHist() : sum_grad{0}, sum_hess{0} {
|
||||||
|
static_assert(sizeof(GradStatHist) == 8,
|
||||||
|
"Size of GradStatHist is not 8 bytes.");
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void Add(const GradStatHist& b) {
|
||||||
|
sum_grad += b.sum_grad;
|
||||||
|
sum_hess += b.sum_hess;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void Add(const tree::GradStats& b) {
|
||||||
|
sum_grad += b.sum_grad;
|
||||||
|
sum_hess += b.sum_hess;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void Add(const GradientPair& p) {
|
||||||
|
this->Add(p.GetGrad(), p.GetHess());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void Add(const GradType& grad, const GradType& hess) {
|
||||||
|
sum_grad += grad;
|
||||||
|
sum_hess += hess;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline tree::GradStats ToGradStat() const {
|
||||||
|
return tree::GradStats(sum_grad, sum_hess);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void SetSubstract(const GradStatHist& a, const GradStatHist& b) {
|
||||||
|
sum_grad = a.sum_grad - b.sum_grad;
|
||||||
|
sum_hess = a.sum_hess - b.sum_hess;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void SetSubstract(const tree::GradStats& a, const GradStatHist& b) {
|
||||||
|
sum_grad = a.sum_grad - b.sum_grad;
|
||||||
|
sum_hess = a.sum_hess - b.sum_hess;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline GradType GetGrad() const { return sum_grad; }
|
||||||
|
inline GradType GetHess() const { return sum_hess; }
|
||||||
|
inline static void Reduce(GradStatHist& a, const GradStatHist& b) { // NOLINT(*)
|
||||||
|
a.Add(b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using GHistRow = Span<GradStatHist>;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief histogram of gradient statistics for multiple nodes
|
* \brief histogram of gradient statistics for multiple nodes
|
||||||
@ -215,49 +309,43 @@ using GHistRow = Span<tree::GradStats>;
|
|||||||
class HistCollection {
|
class HistCollection {
|
||||||
public:
|
public:
|
||||||
// access histogram for i-th node
|
// access histogram for i-th node
|
||||||
GHistRow operator[](bst_uint nid) const {
|
inline GHistRow operator[](bst_uint nid) {
|
||||||
constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
|
AddHistRow(nid);
|
||||||
CHECK_NE(row_ptr_[nid], kMax);
|
return { const_cast<GradStatHist*>(dmlc::BeginPtr(data_arr_[nid])), nbins_};
|
||||||
tree::GradStats* ptr =
|
|
||||||
const_cast<tree::GradStats*>(dmlc::BeginPtr(data_) + row_ptr_[nid]);
|
|
||||||
return {ptr, nbins_};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// have we computed a histogram for i-th node?
|
// have we computed a histogram for i-th node?
|
||||||
bool RowExists(bst_uint nid) const {
|
inline bool RowExists(bst_uint nid) const {
|
||||||
const uint32_t k_max = std::numeric_limits<uint32_t>::max();
|
return nid < data_arr_.size();
|
||||||
return (nid < row_ptr_.size() && row_ptr_[nid] != k_max);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize histogram collection
|
// initialize histogram collection
|
||||||
void Init(uint32_t nbins) {
|
inline void Init(uint32_t nbins) {
|
||||||
|
if (nbins_ != nbins) {
|
||||||
|
data_arr_.clear();
|
||||||
nbins_ = nbins;
|
nbins_ = nbins;
|
||||||
row_ptr_.clear();
|
}
|
||||||
data_.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// create an empty histogram for i-th node
|
// create an empty histogram for i-th node
|
||||||
void AddHistRow(bst_uint nid) {
|
inline void AddHistRow(bst_uint nid) {
|
||||||
constexpr uint32_t kMax = std::numeric_limits<uint32_t>::max();
|
if (data_arr_.size() <= nid) {
|
||||||
if (nid >= row_ptr_.size()) {
|
size_t prev = data_arr_.size();
|
||||||
row_ptr_.resize(nid + 1, kMax);
|
data_arr_.resize(nid + 1);
|
||||||
}
|
|
||||||
CHECK_EQ(row_ptr_[nid], kMax);
|
|
||||||
|
|
||||||
row_ptr_[nid] = data_.size();
|
for (size_t i = prev; i < data_arr_.size(); ++i) {
|
||||||
data_.resize(data_.size() + nbins_);
|
data_arr_[i].resize(nbins_);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/*! \brief number of all bins over all features */
|
/*! \brief number of all bins over all features */
|
||||||
uint32_t nbins_;
|
uint32_t nbins_ = 0;
|
||||||
|
std::vector<std::vector<GradStatHist>> data_arr_;
|
||||||
std::vector<tree::GradStats> data_;
|
|
||||||
|
|
||||||
/*! \brief row_ptr_[nid] locates bin for historgram of node nid */
|
|
||||||
std::vector<size_t> row_ptr_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief builder for histograms of gradient statistics
|
* \brief builder for histograms of gradient statistics
|
||||||
*/
|
*/
|
||||||
@ -267,21 +355,55 @@ class GHistBuilder {
|
|||||||
inline void Init(size_t nthread, uint32_t nbins) {
|
inline void Init(size_t nthread, uint32_t nbins) {
|
||||||
nthread_ = nthread;
|
nthread_ = nthread;
|
||||||
nbins_ = nbins;
|
nbins_ = nbins;
|
||||||
thread_init_.resize(nthread_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// construct a histogram via histogram aggregation
|
|
||||||
void BuildHist(const std::vector<GradientPair>& gpair,
|
|
||||||
const RowSetCollection::Elem row_indices,
|
|
||||||
const GHistIndexMatrix& gmat,
|
|
||||||
GHistRow hist);
|
|
||||||
// same, with feature grouping
|
|
||||||
void BuildBlockHist(const std::vector<GradientPair>& gpair,
|
void BuildBlockHist(const std::vector<GradientPair>& gpair,
|
||||||
const RowSetCollection::Elem row_indices,
|
const RowSetCollection::Elem row_indices,
|
||||||
const GHistIndexBlockMatrix& gmatb,
|
const GHistIndexBlockMatrix& gmatb,
|
||||||
GHistRow hist);
|
GHistRow hist) {
|
||||||
// construct a histogram via subtraction trick
|
constexpr int kUnroll = 8; // loop unrolling factor
|
||||||
void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
|
const int32_t nblock = gmatb.GetNumBlock();
|
||||||
|
const size_t nrows = row_indices.end - row_indices.begin;
|
||||||
|
const size_t rest = nrows % kUnroll;
|
||||||
|
|
||||||
|
#pragma omp parallel for
|
||||||
|
for (int32_t bid = 0; bid < nblock; ++bid) {
|
||||||
|
auto gmat = gmatb[bid];
|
||||||
|
|
||||||
|
for (size_t i = 0; i < nrows - rest; i += kUnroll) {
|
||||||
|
size_t rid[kUnroll];
|
||||||
|
size_t ibegin[kUnroll];
|
||||||
|
size_t iend[kUnroll];
|
||||||
|
GradientPair stat[kUnroll];
|
||||||
|
for (int k = 0; k < kUnroll; ++k) {
|
||||||
|
rid[k] = row_indices.begin[i + k];
|
||||||
|
}
|
||||||
|
for (int k = 0; k < kUnroll; ++k) {
|
||||||
|
ibegin[k] = gmat.row_ptr[rid[k]];
|
||||||
|
iend[k] = gmat.row_ptr[rid[k] + 1];
|
||||||
|
}
|
||||||
|
for (int k = 0; k < kUnroll; ++k) {
|
||||||
|
stat[k] = gpair[rid[k]];
|
||||||
|
}
|
||||||
|
for (int k = 0; k < kUnroll; ++k) {
|
||||||
|
for (size_t j = ibegin[k]; j < iend[k]; ++j) {
|
||||||
|
const uint32_t bin = gmat.index[j];
|
||||||
|
hist[bin].Add(stat[k]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (size_t i = nrows - rest; i < nrows; ++i) {
|
||||||
|
const size_t rid = row_indices.begin[i];
|
||||||
|
const size_t ibegin = gmat.row_ptr[rid];
|
||||||
|
const size_t iend = gmat.row_ptr[rid + 1];
|
||||||
|
const GradientPair stat = gpair[rid];
|
||||||
|
for (size_t j = ibegin; j < iend; ++j) {
|
||||||
|
const uint32_t bin = gmat.index[j];
|
||||||
|
hist[bin].Add(stat);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t GetNumBins() {
|
uint32_t GetNumBins() {
|
||||||
return nbins_;
|
return nbins_;
|
||||||
@ -292,11 +414,19 @@ class GHistBuilder {
|
|||||||
size_t nthread_;
|
size_t nthread_;
|
||||||
/*! \brief number of all bins over all features */
|
/*! \brief number of all bins over all features */
|
||||||
uint32_t nbins_;
|
uint32_t nbins_;
|
||||||
std::vector<size_t> thread_init_;
|
|
||||||
std::vector<tree::GradStats> data_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
void BuildHistLocalDense(size_t istart, size_t iend, size_t nrows, const size_t* rid,
|
||||||
|
const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
|
||||||
|
GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat);
|
||||||
|
|
||||||
|
void BuildHistLocalSparse(size_t istart, size_t iend, size_t nrows, const size_t* rid,
|
||||||
|
const uint32_t* index, const GradientPair::ValueT* pgh, const size_t* row_ptr,
|
||||||
|
GradStatHist::GradType* data_local_hist, GradStatHist* grad_stat);
|
||||||
|
|
||||||
|
void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_COMMON_HIST_UTIL_H_
|
#endif // XGBOOST_COMMON_HIST_UTIL_H_
|
||||||
|
|||||||
@ -27,10 +27,10 @@ class RowSetCollection {
|
|||||||
// id of node associated with this instance set; -1 means uninitialized
|
// id of node associated with this instance set; -1 means uninitialized
|
||||||
Elem()
|
Elem()
|
||||||
= default;
|
= default;
|
||||||
Elem(const size_t* begin,
|
Elem(const size_t* begin_,
|
||||||
const size_t* end,
|
const size_t* end_,
|
||||||
int node_id)
|
int node_id_)
|
||||||
: begin(begin), end(end), node_id(node_id) {}
|
: begin(begin_), end(end_), node_id(node_id_) {}
|
||||||
|
|
||||||
inline size_t Size() const {
|
inline size_t Size() const {
|
||||||
return end - begin;
|
return end - begin;
|
||||||
@ -42,6 +42,10 @@ class RowSetCollection {
|
|||||||
std::vector<size_t> right;
|
std::vector<size_t> right;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
size_t Size(unsigned node_id) {
|
||||||
|
return elem_of_each_node_[node_id].Size();
|
||||||
|
}
|
||||||
|
|
||||||
inline std::vector<Elem>::const_iterator begin() const { // NOLINT
|
inline std::vector<Elem>::const_iterator begin() const { // NOLINT
|
||||||
return elem_of_each_node_.begin();
|
return elem_of_each_node_.begin();
|
||||||
}
|
}
|
||||||
@ -51,12 +55,12 @@ class RowSetCollection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief return corresponding element set given the node_id */
|
/*! \brief return corresponding element set given the node_id */
|
||||||
inline const Elem& operator[](unsigned node_id) const {
|
inline Elem operator[](unsigned node_id) const {
|
||||||
const Elem& e = elem_of_each_node_[node_id];
|
const Elem e = elem_of_each_node_[node_id];
|
||||||
CHECK(e.begin != nullptr)
|
|
||||||
<< "access element that is not in the set";
|
|
||||||
return e;
|
return e;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// clear up things
|
// clear up things
|
||||||
inline void Clear() {
|
inline void Clear() {
|
||||||
elem_of_each_node_.clear();
|
elem_of_each_node_.clear();
|
||||||
@ -81,38 +85,29 @@ class RowSetCollection {
|
|||||||
const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
|
const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
|
||||||
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
|
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
// split rowset into two
|
// split rowset into two
|
||||||
inline void AddSplit(unsigned node_id,
|
inline void AddSplit(unsigned node_id,
|
||||||
const std::vector<Split>& row_split_tloc,
|
size_t iLeft,
|
||||||
unsigned left_node_id,
|
unsigned left_node_id,
|
||||||
unsigned right_node_id) {
|
unsigned right_node_id) {
|
||||||
const Elem e = elem_of_each_node_[node_id];
|
Elem e = elem_of_each_node_[node_id];
|
||||||
const auto nthread = static_cast<bst_omp_uint>(row_split_tloc.size());
|
|
||||||
CHECK(e.begin != nullptr);
|
|
||||||
size_t* all_begin = dmlc::BeginPtr(row_indices_);
|
|
||||||
size_t* begin = all_begin + (e.begin - all_begin);
|
|
||||||
|
|
||||||
size_t* it = begin;
|
CHECK(e.begin != nullptr);
|
||||||
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
|
|
||||||
std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
|
size_t* begin = const_cast<size_t*>(e.begin);
|
||||||
it += row_split_tloc[tid].left.size();
|
size_t* split_pt = begin + iLeft;
|
||||||
}
|
|
||||||
size_t* split_pt = it;
|
|
||||||
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
|
|
||||||
std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
|
|
||||||
it += row_split_tloc[tid].right.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (left_node_id >= elem_of_each_node_.size()) {
|
if (left_node_id >= elem_of_each_node_.size()) {
|
||||||
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
|
elem_of_each_node_.resize((left_node_id + 1)*2, Elem(nullptr, nullptr, -1));
|
||||||
}
|
}
|
||||||
if (right_node_id >= elem_of_each_node_.size()) {
|
if (right_node_id >= elem_of_each_node_.size()) {
|
||||||
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
|
elem_of_each_node_.resize((right_node_id + 1)*2, Elem(nullptr, nullptr, -1));
|
||||||
}
|
}
|
||||||
|
|
||||||
elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id);
|
elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id);
|
||||||
elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id);
|
elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id);
|
||||||
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
|
elem_of_each_node_[node_id] = Elem(begin, e.end, -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
// stores the row indices in the set
|
// stores the row indices in the set
|
||||||
|
|||||||
@ -291,7 +291,7 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
T w = CalcWeight(p, sum_grad, sum_hess);
|
T w = CalcWeight(p, sum_grad, sum_hess);
|
||||||
T ret = CalcGainGivenWeight(p, sum_grad, sum_hess, w);
|
T ret = CalcGainGivenWeight<TrainingParams, T>(p, sum_grad, sum_hess, w);
|
||||||
if (p.reg_alpha == 0.0f) {
|
if (p.reg_alpha == 0.0f) {
|
||||||
return ret;
|
return ret;
|
||||||
} else {
|
} else {
|
||||||
@ -311,7 +311,7 @@ template <typename TrainingParams, typename T>
|
|||||||
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
|
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
|
||||||
T test_grad, T test_hess) {
|
T test_grad, T test_hess) {
|
||||||
T w = CalcWeight(sum_grad, sum_hess);
|
T w = CalcWeight(sum_grad, sum_hess);
|
||||||
T ret = CalcGainGivenWeight(p, test_grad, test_hess);
|
T ret = CalcGainGivenWeight<TrainingParams, T>(p, test_grad, test_hess);
|
||||||
if (p.reg_alpha == 0.0f) {
|
if (p.reg_alpha == 0.0f) {
|
||||||
return ret;
|
return ret;
|
||||||
} else {
|
} else {
|
||||||
@ -350,15 +350,16 @@ XGBOOST_DEVICE inline float CalcWeight(const TrainingParams &p, GpairT sum_grad)
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief core statistics used for tree construction */
|
/*! \brief core statistics used for tree construction */
|
||||||
struct XGBOOST_ALIGNAS(16) GradStats {
|
struct GradStats {
|
||||||
|
typedef double GradType;
|
||||||
/*! \brief sum gradient statistics */
|
/*! \brief sum gradient statistics */
|
||||||
double sum_grad;
|
GradType sum_grad;
|
||||||
/*! \brief sum hessian statistics */
|
/*! \brief sum hessian statistics */
|
||||||
double sum_hess;
|
GradType sum_hess;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
XGBOOST_DEVICE double GetGrad() const { return sum_grad; }
|
XGBOOST_DEVICE GradType GetGrad() const { return sum_grad; }
|
||||||
XGBOOST_DEVICE double GetHess() const { return sum_hess; }
|
XGBOOST_DEVICE GradType GetHess() const { return sum_hess; }
|
||||||
|
|
||||||
XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
|
XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
|
||||||
static_assert(sizeof(GradStats) == 16,
|
static_assert(sizeof(GradStats) == 16,
|
||||||
@ -368,7 +369,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
|
|||||||
template <typename GpairT>
|
template <typename GpairT>
|
||||||
XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
|
XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
|
||||||
: sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {}
|
: sum_grad(sum.GetGrad()), sum_hess(sum.GetHess()) {}
|
||||||
explicit GradStats(const double grad, const double hess)
|
explicit GradStats(const GradType grad, const GradType hess)
|
||||||
: sum_grad(grad), sum_hess(hess) {}
|
: sum_grad(grad), sum_hess(hess) {}
|
||||||
/*!
|
/*!
|
||||||
* \brief accumulate statistics
|
* \brief accumulate statistics
|
||||||
@ -393,7 +394,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
|
|||||||
/*! \return whether the statistics is not used yet */
|
/*! \return whether the statistics is not used yet */
|
||||||
inline bool Empty() const { return sum_hess == 0.0; }
|
inline bool Empty() const { return sum_hess == 0.0; }
|
||||||
/*! \brief add statistics to the data */
|
/*! \brief add statistics to the data */
|
||||||
inline void Add(double grad, double hess) {
|
inline void Add(GradType grad, GradType hess) {
|
||||||
sum_grad += grad;
|
sum_grad += grad;
|
||||||
sum_hess += hess;
|
sum_hess += hess;
|
||||||
}
|
}
|
||||||
@ -423,7 +424,7 @@ struct ValueConstraint {
|
|||||||
|
|
||||||
template <typename ParamT>
|
template <typename ParamT>
|
||||||
XGBOOST_DEVICE inline double CalcGain(const ParamT ¶m, GradStats stats) const {
|
XGBOOST_DEVICE inline double CalcGain(const ParamT ¶m, GradStats stats) const {
|
||||||
return CalcGainGivenWeight(param, stats.sum_grad, stats.sum_hess,
|
return CalcGainGivenWeight<ParamT, float>(param, stats.sum_grad, stats.sum_hess,
|
||||||
CalcWeight(param, stats));
|
CalcWeight(param, stats));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -434,8 +435,8 @@ struct ValueConstraint {
|
|||||||
double wleft = CalcWeight(param, left);
|
double wleft = CalcWeight(param, left);
|
||||||
double wright = CalcWeight(param, right);
|
double wright = CalcWeight(param, right);
|
||||||
double gain =
|
double gain =
|
||||||
CalcGainGivenWeight(param, left.sum_grad, left.sum_hess, wleft) +
|
CalcGainGivenWeight<ParamT, float>(param, left.sum_grad, left.sum_hess, wleft) +
|
||||||
CalcGainGivenWeight(param, right.sum_grad, right.sum_hess, wright);
|
CalcGainGivenWeight<ParamT, float>(param, right.sum_grad, right.sum_hess, wright);
|
||||||
if (constraint == 0) {
|
if (constraint == 0) {
|
||||||
return gain;
|
return gain;
|
||||||
} else if (constraint > 0) {
|
} else if (constraint > 0) {
|
||||||
@ -480,6 +481,7 @@ struct SplitEntry {
|
|||||||
bst_float split_value{0.0f};
|
bst_float split_value{0.0f};
|
||||||
GradStats left_sum;
|
GradStats left_sum;
|
||||||
GradStats right_sum;
|
GradStats right_sum;
|
||||||
|
bool default_left{true};
|
||||||
|
|
||||||
/*! \brief constructor */
|
/*! \brief constructor */
|
||||||
SplitEntry() = default;
|
SplitEntry() = default;
|
||||||
@ -494,7 +496,11 @@ struct SplitEntry {
|
|||||||
* \param split_index the feature index where the split is on
|
* \param split_index the feature index where the split is on
|
||||||
*/
|
*/
|
||||||
inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
|
inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
|
||||||
if (this->SplitIndex() <= split_index) {
|
if (!std::isfinite(new_loss_chg)) { // in some cases new_loss_chg can be NaN or Inf,
|
||||||
|
// for example when lambda = 0 & min_child_weight = 0
|
||||||
|
// skip value in this case
|
||||||
|
return false;
|
||||||
|
} else if (this->SplitIndex() <= split_index) {
|
||||||
return new_loss_chg > this->loss_chg;
|
return new_loss_chg > this->loss_chg;
|
||||||
} else {
|
} else {
|
||||||
return !(this->loss_chg > new_loss_chg);
|
return !(this->loss_chg > new_loss_chg);
|
||||||
@ -512,6 +518,7 @@ struct SplitEntry {
|
|||||||
this->split_value = e.split_value;
|
this->split_value = e.split_value;
|
||||||
this->left_sum = e.left_sum;
|
this->left_sum = e.left_sum;
|
||||||
this->right_sum = e.right_sum;
|
this->right_sum = e.right_sum;
|
||||||
|
this->default_left = e.default_left;
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
return false;
|
return false;
|
||||||
@ -526,13 +533,11 @@ struct SplitEntry {
|
|||||||
* \return whether the proposed split is better and can replace current split
|
* \return whether the proposed split is better and can replace current split
|
||||||
*/
|
*/
|
||||||
inline bool Update(bst_float new_loss_chg, unsigned split_index,
|
inline bool Update(bst_float new_loss_chg, unsigned split_index,
|
||||||
bst_float new_split_value, bool default_left,
|
bst_float new_split_value, bool new_default_left,
|
||||||
const GradStats &left_sum, const GradStats &right_sum) {
|
const GradStats &left_sum, const GradStats &right_sum) {
|
||||||
if (this->NeedReplace(new_loss_chg, split_index)) {
|
if (this->NeedReplace(new_loss_chg, split_index)) {
|
||||||
this->loss_chg = new_loss_chg;
|
this->loss_chg = new_loss_chg;
|
||||||
if (default_left) {
|
this->default_left = new_default_left;
|
||||||
split_index |= (1U << 31);
|
|
||||||
}
|
|
||||||
this->sindex = split_index;
|
this->sindex = split_index;
|
||||||
this->split_value = new_split_value;
|
this->split_value = new_split_value;
|
||||||
this->left_sum = left_sum;
|
this->left_sum = left_sum;
|
||||||
@ -548,9 +553,9 @@ struct SplitEntry {
|
|||||||
dst.Update(src);
|
dst.Update(src);
|
||||||
}
|
}
|
||||||
/*!\return feature index to split on */
|
/*!\return feature index to split on */
|
||||||
inline unsigned SplitIndex() const { return sindex & ((1U << 31) - 1U); }
|
inline unsigned SplitIndex() const { return sindex; }
|
||||||
/*!\return whether missing value goes to left branch */
|
/*!\return whether missing value goes to left branch */
|
||||||
inline bool DefaultLeft() const { return (sindex >> 31) != 0; }
|
inline bool DefaultLeft() const { return default_left; }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
|
|||||||
@ -283,7 +283,9 @@ class MonotonicConstraint final : public SplitEvaluator {
|
|||||||
bst_float leftweight,
|
bst_float leftweight,
|
||||||
bst_float rightweight) override {
|
bst_float rightweight) override {
|
||||||
inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight);
|
inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight);
|
||||||
bst_uint newsize = std::max(leftid, rightid) + 1;
|
|
||||||
|
bst_uint newsize = std::max(bst_uint(lower_.size()), bst_uint(std::max(leftid, rightid) + 1u));
|
||||||
|
|
||||||
lower_.resize(newsize);
|
lower_.resize(newsize);
|
||||||
upper_.resize(newsize);
|
upper_.resize(newsize);
|
||||||
bst_int constraint = GetConstraint(featureid);
|
bst_int constraint = GetConstraint(featureid);
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,8 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2017-2018 by Contributors
|
* Copyright 2017-2019 by Contributors
|
||||||
* \file updater_quantile_hist.h
|
* \file updater_quantile_hist.h
|
||||||
* \brief use quantized feature values to construct a tree
|
* \brief use quantized feature values to construct a tree
|
||||||
* \author Philip Cho, Tianqi Chen
|
* \author Philip Cho, Tianqi Chen, Egor Smirnov
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
|
#ifndef XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
|
||||||
#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
|
#define XGBOOST_TREE_UPDATER_QUANTILE_HIST_H_
|
||||||
@ -18,51 +18,19 @@
|
|||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
#include "./param.h"
|
#include "./param.h"
|
||||||
#include "./split_evaluator.h"
|
#include "./split_evaluator.h"
|
||||||
#include "../common/random.h"
|
#include "../common/random.h"
|
||||||
#include "../common/timer.h"
|
|
||||||
#include "../common/hist_util.h"
|
#include "../common/hist_util.h"
|
||||||
#include "../common/row_set.h"
|
#include "../common/row_set.h"
|
||||||
#include "../common/column_matrix.h"
|
#include "../common/column_matrix.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
namespace common {
|
||||||
/*!
|
struct GradStatHist;
|
||||||
* \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated.
|
}
|
||||||
*/
|
|
||||||
template<typename T, size_t MaxStackSize>
|
|
||||||
class MemStackAllocator {
|
|
||||||
public:
|
|
||||||
explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
|
|
||||||
}
|
|
||||||
|
|
||||||
T* Get() {
|
|
||||||
if (!ptr_) {
|
|
||||||
if (MaxStackSize >= required_size_) {
|
|
||||||
ptr_ = stack_mem_;
|
|
||||||
} else {
|
|
||||||
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
|
|
||||||
do_free_ = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ptr_;
|
|
||||||
}
|
|
||||||
|
|
||||||
~MemStackAllocator() {
|
|
||||||
if (do_free_) free(ptr_);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private:
|
|
||||||
T* ptr_ = nullptr;
|
|
||||||
bool do_free_ = false;
|
|
||||||
size_t required_size_;
|
|
||||||
T stack_mem_[MaxStackSize];
|
|
||||||
};
|
|
||||||
|
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|
||||||
using xgboost::common::HistCutMatrix;
|
using xgboost::common::HistCutMatrix;
|
||||||
@ -88,6 +56,7 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
bool UpdatePredictionCache(const DMatrix* data,
|
bool UpdatePredictionCache(const DMatrix* data,
|
||||||
HostDeviceVector<bst_float>* out_preds) override;
|
HostDeviceVector<bst_float>* out_preds) override;
|
||||||
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// training parameter
|
// training parameter
|
||||||
TrainParam param_;
|
TrainParam param_;
|
||||||
@ -100,6 +69,7 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
bool is_gmat_initialized_;
|
bool is_gmat_initialized_;
|
||||||
|
|
||||||
// data structure
|
// data structure
|
||||||
|
public:
|
||||||
struct NodeEntry {
|
struct NodeEntry {
|
||||||
/*! \brief statics for node entry */
|
/*! \brief statics for node entry */
|
||||||
GradStats stats;
|
GradStats stats;
|
||||||
@ -111,7 +81,8 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
SplitEntry best;
|
SplitEntry best;
|
||||||
// constructor
|
// constructor
|
||||||
explicit NodeEntry(const TrainParam& param)
|
explicit NodeEntry(const TrainParam& param)
|
||||||
: root_gain(0.0f), weight(0.0f) {}
|
: root_gain(0.0f), weight(0.0f) {
|
||||||
|
}
|
||||||
};
|
};
|
||||||
// actual builder that runs the algorithm
|
// actual builder that runs the algorithm
|
||||||
|
|
||||||
@ -121,11 +92,8 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
explicit Builder(const TrainParam& param,
|
explicit Builder(const TrainParam& param,
|
||||||
std::unique_ptr<TreeUpdater> pruner,
|
std::unique_ptr<TreeUpdater> pruner,
|
||||||
std::unique_ptr<SplitEvaluator> spliteval)
|
std::unique_ptr<SplitEvaluator> spliteval)
|
||||||
: param_(param), pruner_(std::move(pruner)),
|
: param_(param), pruner_(std::move(pruner)), spliteval_(std::move(spliteval)),
|
||||||
spliteval_(std::move(spliteval)), p_last_tree_(nullptr),
|
p_last_tree_(nullptr), p_last_fmat_(nullptr) { }
|
||||||
p_last_fmat_(nullptr) {
|
|
||||||
builder_monitor_.Init("Quantile::Builder");
|
|
||||||
}
|
|
||||||
// update one tree, growing
|
// update one tree, growing
|
||||||
virtual void Update(const GHistIndexMatrix& gmat,
|
virtual void Update(const GHistIndexMatrix& gmat,
|
||||||
const GHistIndexBlockMatrix& gmatb,
|
const GHistIndexBlockMatrix& gmatb,
|
||||||
@ -134,42 +102,104 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
DMatrix* p_fmat,
|
DMatrix* p_fmat,
|
||||||
RegTree* p_tree);
|
RegTree* p_tree);
|
||||||
|
|
||||||
inline void BuildHist(const std::vector<GradientPair>& gpair,
|
|
||||||
const RowSetCollection::Elem row_indices,
|
|
||||||
const GHistIndexMatrix& gmat,
|
|
||||||
const GHistIndexBlockMatrix& gmatb,
|
|
||||||
GHistRow hist,
|
|
||||||
bool sync_hist) {
|
|
||||||
builder_monitor_.Start("BuildHist");
|
|
||||||
if (param_.enable_feature_grouping > 0) {
|
|
||||||
hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist);
|
|
||||||
} else {
|
|
||||||
hist_builder_.BuildHist(gpair, row_indices, gmat, hist);
|
|
||||||
}
|
|
||||||
if (sync_hist) {
|
|
||||||
this->histred_.Allreduce(hist.data(), hist_builder_.GetNumBins());
|
|
||||||
}
|
|
||||||
builder_monitor_.Stop("BuildHist");
|
|
||||||
}
|
|
||||||
|
|
||||||
inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
|
|
||||||
builder_monitor_.Start("SubtractionTrick");
|
|
||||||
hist_builder_.SubtractionTrick(self, sibling, parent);
|
|
||||||
builder_monitor_.Stop("SubtractionTrick");
|
|
||||||
}
|
|
||||||
|
|
||||||
bool UpdatePredictionCache(const DMatrix* data,
|
bool UpdatePredictionCache(const DMatrix* data,
|
||||||
HostDeviceVector<bst_float>* p_out_preds);
|
HostDeviceVector<bst_float>* p_out_preds);
|
||||||
|
|
||||||
|
std::tuple<common::GradStatHist::GradType*, common::GradStatHist*>
|
||||||
|
GetHistBuffer(std::vector<uint8_t>* hist_is_init,
|
||||||
|
std::vector<common::GradStatHist>* grad_stats, size_t block_id, size_t nthread,
|
||||||
|
size_t tid, std::vector<common::GradStatHist::GradType*>* data_hist, size_t hist_size);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/* tree growing policies */
|
/* tree growing policies */
|
||||||
struct ExpandEntry {
|
struct ExpandEntry {
|
||||||
int nid;
|
int nid;
|
||||||
|
int sibling_nid;
|
||||||
|
int parent_nid;
|
||||||
int depth;
|
int depth;
|
||||||
bst_float loss_chg;
|
bst_float loss_chg;
|
||||||
unsigned timestamp;
|
unsigned timestamp;
|
||||||
ExpandEntry(int nid, int depth, bst_float loss_chg, unsigned tstmp)
|
ExpandEntry(int nid, int sibling_nid, int parent_nid, int depth, bst_float loss_chg,
|
||||||
: nid(nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {}
|
unsigned tstmp) : nid(nid), sibling_nid(sibling_nid), parent_nid(parent_nid),
|
||||||
|
depth(depth), loss_chg(loss_chg), timestamp(tstmp) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TreeGrowingPerfMonitor {
|
||||||
|
enum timer_name {INIT_DATA, INIT_NEW_NODE, BUILD_HIST, EVALUATE_SPLIT, APPLY_SPLIT};
|
||||||
|
|
||||||
|
double global_start;
|
||||||
|
|
||||||
|
// performance counters
|
||||||
|
double tstart;
|
||||||
|
double time_init_data = 0;
|
||||||
|
double time_init_new_node = 0;
|
||||||
|
double time_build_hist = 0;
|
||||||
|
double time_evaluate_split = 0;
|
||||||
|
double time_apply_split = 0;
|
||||||
|
|
||||||
|
inline void StartPerfMonitor() {
|
||||||
|
global_start = dmlc::GetTime();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void EndPerfMonitor() {
|
||||||
|
CHECK_GT(global_start, 0);
|
||||||
|
double total_time = dmlc::GetTime() - global_start;
|
||||||
|
LOG(INFO) << "\nInitData: "
|
||||||
|
<< std::fixed << std::setw(6) << std::setprecision(4) << time_init_data
|
||||||
|
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
|
||||||
|
<< time_init_data / total_time * 100 << "%)\n"
|
||||||
|
<< "InitNewNode: "
|
||||||
|
<< std::fixed << std::setw(6) << std::setprecision(4) << time_init_new_node
|
||||||
|
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
|
||||||
|
<< time_init_new_node / total_time * 100 << "%)\n"
|
||||||
|
<< "BuildHist: "
|
||||||
|
<< std::fixed << std::setw(6) << std::setprecision(4) << time_build_hist
|
||||||
|
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
|
||||||
|
<< time_build_hist / total_time * 100 << "%)\n"
|
||||||
|
<< "EvaluateSplit: "
|
||||||
|
<< std::fixed << std::setw(6) << std::setprecision(4) << time_evaluate_split
|
||||||
|
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
|
||||||
|
<< time_evaluate_split / total_time * 100 << "%)\n"
|
||||||
|
<< "ApplySplit: "
|
||||||
|
<< std::fixed << std::setw(6) << std::setprecision(4) << time_apply_split
|
||||||
|
<< " (" << std::fixed << std::setw(5) << std::setprecision(2)
|
||||||
|
<< time_apply_split / total_time * 100 << "%)\n"
|
||||||
|
<< "========================================\n"
|
||||||
|
<< "Total: "
|
||||||
|
<< std::fixed << std::setw(6) << std::setprecision(4) << total_time << std::endl;
|
||||||
|
// clear performance counters
|
||||||
|
time_init_data = 0;
|
||||||
|
time_init_new_node = 0;
|
||||||
|
time_build_hist = 0;
|
||||||
|
time_evaluate_split = 0;
|
||||||
|
time_apply_split = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void TickStart() {
|
||||||
|
tstart = dmlc::GetTime();
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void UpdatePerfTimer(const timer_name &timer_name) {
|
||||||
|
// CHECK_GT(tstart, 0); // TODO Fix
|
||||||
|
switch (timer_name) {
|
||||||
|
case INIT_DATA:
|
||||||
|
time_init_data += dmlc::GetTime() - tstart;
|
||||||
|
break;
|
||||||
|
case INIT_NEW_NODE:
|
||||||
|
time_init_new_node += dmlc::GetTime() - tstart;
|
||||||
|
break;
|
||||||
|
case BUILD_HIST:
|
||||||
|
time_build_hist += dmlc::GetTime() - tstart;
|
||||||
|
break;
|
||||||
|
case EVALUATE_SPLIT:
|
||||||
|
time_evaluate_split += dmlc::GetTime() - tstart;
|
||||||
|
break;
|
||||||
|
case APPLY_SPLIT:
|
||||||
|
time_apply_split += dmlc::GetTime() - tstart;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
tstart = -1;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// initialize temp data structure
|
// initialize temp data structure
|
||||||
@ -178,43 +208,16 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
const DMatrix& fmat,
|
const DMatrix& fmat,
|
||||||
const RegTree& tree);
|
const RegTree& tree);
|
||||||
|
|
||||||
void EvaluateSplit(const int nid,
|
|
||||||
const GHistIndexMatrix& gmat,
|
|
||||||
const HistCollection& hist,
|
|
||||||
const DMatrix& fmat,
|
|
||||||
const RegTree& tree);
|
|
||||||
|
|
||||||
void ApplySplit(int nid,
|
|
||||||
const GHistIndexMatrix& gmat,
|
|
||||||
const ColumnMatrix& column_matrix,
|
|
||||||
const HistCollection& hist,
|
|
||||||
const DMatrix& fmat,
|
|
||||||
RegTree* p_tree);
|
|
||||||
|
|
||||||
void ApplySplitDenseData(const RowSetCollection::Elem rowset,
|
|
||||||
const GHistIndexMatrix& gmat,
|
|
||||||
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
|
||||||
const Column& column,
|
|
||||||
bst_int split_cond,
|
|
||||||
bool default_left);
|
|
||||||
|
|
||||||
void ApplySplitSparseData(const RowSetCollection::Elem rowset,
|
|
||||||
const GHistIndexMatrix& gmat,
|
|
||||||
std::vector<RowSetCollection::Split>* p_row_split_tloc,
|
|
||||||
const Column& column,
|
|
||||||
bst_uint lower_bound,
|
|
||||||
bst_uint upper_bound,
|
|
||||||
bst_int split_cond,
|
|
||||||
bool default_left);
|
|
||||||
|
|
||||||
void InitNewNode(int nid,
|
void InitNewNode(int nid,
|
||||||
const GHistIndexMatrix& gmat,
|
const GHistIndexMatrix& gmat,
|
||||||
const std::vector<GradientPair>& gpair,
|
const std::vector<GradientPair>& gpair,
|
||||||
const DMatrix& fmat,
|
const DMatrix& fmat,
|
||||||
const RegTree& tree);
|
RegTree* tree,
|
||||||
|
QuantileHistMaker::NodeEntry* snode,
|
||||||
|
int32_t parentid);
|
||||||
|
|
||||||
// enumerate the split values of specific feature
|
// enumerate the split values of specific feature
|
||||||
void EnumerateSplit(int d_step,
|
bool EnumerateSplit(int d_step,
|
||||||
const GHistIndexMatrix& gmat,
|
const GHistIndexMatrix& gmat,
|
||||||
const GHistRow& hist,
|
const GHistRow& hist,
|
||||||
const NodeEntry& snode,
|
const NodeEntry& snode,
|
||||||
@ -223,37 +226,36 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
bst_uint fid,
|
bst_uint fid,
|
||||||
bst_uint nodeID);
|
bst_uint nodeID);
|
||||||
|
|
||||||
void ExpandWithDepthWidth(const GHistIndexMatrix &gmat,
|
void EvaluateSplitsBatch(const std::vector<ExpandEntry>& nodes,
|
||||||
|
const GHistIndexMatrix& gmat,
|
||||||
|
const DMatrix& fmat,
|
||||||
|
const std::vector<std::vector<uint8_t>>& hist_is_init,
|
||||||
|
const std::vector<std::vector<common::GradStatHist::GradType*>>& hist_buffers);
|
||||||
|
|
||||||
|
void ReduceHistograms(
|
||||||
|
common::GradStatHist::GradType* hist_data,
|
||||||
|
common::GradStatHist::GradType* sibling_hist_data,
|
||||||
|
common::GradStatHist::GradType* parent_hist_data,
|
||||||
|
const size_t ibegin,
|
||||||
|
const size_t iend,
|
||||||
|
const size_t inode,
|
||||||
|
const std::vector<std::vector<uint8_t>>& hist_is_init,
|
||||||
|
const std::vector<std::vector<common::GradStatHist::GradType*>>& hist_buffers);
|
||||||
|
|
||||||
|
void SyncHistograms(
|
||||||
|
RegTree* p_tree,
|
||||||
|
const std::vector<ExpandEntry>& nodes,
|
||||||
|
std::vector<std::vector<common::GradStatHist::GradType*>>* hist_buffers,
|
||||||
|
std::vector<std::vector<uint8_t>>* hist_is_init,
|
||||||
|
const std::vector<std::vector<common::GradStatHist>>& grad_stats);
|
||||||
|
|
||||||
|
void ExpandWithDepthWise(const GHistIndexMatrix &gmat,
|
||||||
const GHistIndexBlockMatrix &gmatb,
|
const GHistIndexBlockMatrix &gmatb,
|
||||||
const ColumnMatrix &column_matrix,
|
const ColumnMatrix &column_matrix,
|
||||||
DMatrix *p_fmat,
|
DMatrix *p_fmat,
|
||||||
RegTree *p_tree,
|
RegTree *p_tree,
|
||||||
const std::vector<GradientPair> &gpair_h);
|
const std::vector<GradientPair> &gpair_h);
|
||||||
|
|
||||||
void BuildLocalHistograms(int *starting_index,
|
|
||||||
int *sync_count,
|
|
||||||
const GHistIndexMatrix &gmat,
|
|
||||||
const GHistIndexBlockMatrix &gmatb,
|
|
||||||
RegTree *p_tree,
|
|
||||||
const std::vector<GradientPair> &gpair_h);
|
|
||||||
|
|
||||||
void SyncHistograms(int starting_index,
|
|
||||||
int sync_count,
|
|
||||||
RegTree *p_tree);
|
|
||||||
|
|
||||||
void BuildNodeStats(const GHistIndexMatrix &gmat,
|
|
||||||
DMatrix *p_fmat,
|
|
||||||
RegTree *p_tree,
|
|
||||||
const std::vector<GradientPair> &gpair_h);
|
|
||||||
|
|
||||||
void EvaluateSplits(const GHistIndexMatrix &gmat,
|
|
||||||
const ColumnMatrix &column_matrix,
|
|
||||||
DMatrix *p_fmat,
|
|
||||||
RegTree *p_tree,
|
|
||||||
int *num_leaves,
|
|
||||||
int depth,
|
|
||||||
unsigned *timestamp,
|
|
||||||
std::vector<ExpandEntry> *temp_qexpand_depth);
|
|
||||||
|
|
||||||
void ExpandWithLossGuide(const GHistIndexMatrix& gmat,
|
void ExpandWithLossGuide(const GHistIndexMatrix& gmat,
|
||||||
const GHistIndexBlockMatrix& gmatb,
|
const GHistIndexBlockMatrix& gmatb,
|
||||||
@ -262,6 +264,62 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
RegTree* p_tree,
|
RegTree* p_tree,
|
||||||
const std::vector<GradientPair>& gpair_h);
|
const std::vector<GradientPair>& gpair_h);
|
||||||
|
|
||||||
|
|
||||||
|
void BuildHistsBatch(const std::vector<ExpandEntry>& nodes, RegTree* tree,
|
||||||
|
const GHistIndexMatrix &gmat, const std::vector<GradientPair>& gpair,
|
||||||
|
std::vector<std::vector<common::GradStatHist::GradType*>>* hist_buffers,
|
||||||
|
std::vector<std::vector<uint8_t>>* hist_is_init);
|
||||||
|
|
||||||
|
void BuildNodeStat(const GHistIndexMatrix &gmat,
|
||||||
|
DMatrix *p_fmat,
|
||||||
|
RegTree *p_tree,
|
||||||
|
const std::vector<GradientPair> &gpair_h,
|
||||||
|
int32_t nid);
|
||||||
|
|
||||||
|
void BuildNodeStatBatch(
|
||||||
|
const GHistIndexMatrix &gmat,
|
||||||
|
DMatrix *p_fmat,
|
||||||
|
RegTree *p_tree,
|
||||||
|
const std::vector<GradientPair> &gpair_h,
|
||||||
|
const std::vector<ExpandEntry>& nodes);
|
||||||
|
|
||||||
|
int32_t FindSplitCond(int32_t nid,
|
||||||
|
RegTree *p_tree,
|
||||||
|
const GHistIndexMatrix &gmat);
|
||||||
|
|
||||||
|
void CreateNewNodesBatch(
|
||||||
|
const std::vector<ExpandEntry>& nodes,
|
||||||
|
const GHistIndexMatrix &gmat,
|
||||||
|
const ColumnMatrix &column_matrix,
|
||||||
|
DMatrix *p_fmat,
|
||||||
|
RegTree *p_tree,
|
||||||
|
int *num_leaves,
|
||||||
|
int depth,
|
||||||
|
unsigned *timestamp,
|
||||||
|
std::vector<ExpandEntry> *temp_qexpand_depth);
|
||||||
|
|
||||||
|
template<typename TaskType, typename NodeType>
|
||||||
|
void CreateTasksForApplySplit(
|
||||||
|
const std::vector<ExpandEntry>& nodes,
|
||||||
|
const GHistIndexMatrix &gmat,
|
||||||
|
RegTree *p_tree,
|
||||||
|
int *num_leaves,
|
||||||
|
const int depth,
|
||||||
|
const size_t block_size,
|
||||||
|
std::vector<TaskType>* tasks,
|
||||||
|
std::vector<NodeType>* nodes_bounds);
|
||||||
|
|
||||||
|
void CreateTasksForBuildHist(
|
||||||
|
size_t block_size_rows,
|
||||||
|
size_t nthread,
|
||||||
|
const std::vector<ExpandEntry>& nodes,
|
||||||
|
std::vector<std::vector<common::GradStatHist::GradType*>>* hist_buffers,
|
||||||
|
std::vector<std::vector<uint8_t>>* hist_is_init,
|
||||||
|
std::vector<std::vector<common::GradStatHist>>* grad_stats,
|
||||||
|
std::vector<int32_t>* task_nid,
|
||||||
|
std::vector<int32_t>* task_node_idx,
|
||||||
|
std::vector<int32_t>* task_block_idx);
|
||||||
|
|
||||||
inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
|
inline static bool LossGuide(ExpandEntry lhs, ExpandEntry rhs) {
|
||||||
if (lhs.loss_chg == rhs.loss_chg) {
|
if (lhs.loss_chg == rhs.loss_chg) {
|
||||||
return lhs.timestamp > rhs.timestamp; // favor small timestamp
|
return lhs.timestamp > rhs.timestamp; // favor small timestamp
|
||||||
@ -270,6 +328,8 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HistCollection hist_buff_;
|
||||||
|
|
||||||
// --data fields--
|
// --data fields--
|
||||||
const TrainParam& param_;
|
const TrainParam& param_;
|
||||||
// number of omp thread used during training
|
// number of omp thread used during training
|
||||||
@ -280,6 +340,7 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
// the temp space for split
|
// the temp space for split
|
||||||
std::vector<RowSetCollection::Split> row_split_tloc_;
|
std::vector<RowSetCollection::Split> row_split_tloc_;
|
||||||
std::vector<SplitEntry> best_split_tloc_;
|
std::vector<SplitEntry> best_split_tloc_;
|
||||||
|
std::vector<size_t> buffer_for_partition_;
|
||||||
/*! \brief TreeNode Data: statistics for each constructed node */
|
/*! \brief TreeNode Data: statistics for each constructed node */
|
||||||
std::vector<NodeEntry> snode_;
|
std::vector<NodeEntry> snode_;
|
||||||
/*! \brief culmulative histogram of gradients. */
|
/*! \brief culmulative histogram of gradients. */
|
||||||
@ -311,8 +372,8 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
|
enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
|
||||||
DataLayout data_layout_;
|
DataLayout data_layout_;
|
||||||
|
|
||||||
common::Monitor builder_monitor_;
|
TreeGrowingPerfMonitor perf_monitor;
|
||||||
rabit::Reducer<GradStats, GradStats::Reduce> histred_;
|
rabit::Reducer<common::GradStatHist, common::GradStatHist::Reduce> histred_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<Builder> builder_;
|
std::unique_ptr<Builder> builder_;
|
||||||
|
|||||||
@ -101,8 +101,13 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
RealImpl::InitData(gmat, gpair, fmat, tree);
|
RealImpl::InitData(gmat, gpair, fmat, tree);
|
||||||
GHistIndexBlockMatrix dummy;
|
GHistIndexBlockMatrix dummy;
|
||||||
hist_.AddHistRow(nid);
|
hist_.AddHistRow(nid);
|
||||||
BuildHist(gpair, row_set_collection_[nid],
|
|
||||||
gmat, dummy, hist_[nid], false);
|
std::vector<std::vector<float*>> hist_buffers;
|
||||||
|
std::vector<std::vector<uint8_t>> hist_is_init;
|
||||||
|
std::vector<ExpandEntry> nodes = {ExpandEntry(nid, -1, -1, tree.GetDepth(0), 0.0, 0)};
|
||||||
|
BuildHistsBatch(nodes, const_cast<RegTree*>(&tree), gmat, gpair, &hist_buffers, &hist_is_init);
|
||||||
|
RealImpl::InitNewNode(nid, gmat, gpair, fmat, const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
|
||||||
|
EvaluateSplitsBatch(nodes, gmat, fmat, hist_is_init, hist_buffers);
|
||||||
|
|
||||||
// Check if number of histogram bins is correct
|
// Check if number of histogram bins is correct
|
||||||
ASSERT_EQ(hist_[nid].size(), gmat.cut.row_ptr.back());
|
ASSERT_EQ(hist_[nid].size(), gmat.cut.row_ptr.back());
|
||||||
@ -143,10 +148,12 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
RealImpl::InitData(gmat, row_gpairs, *(*dmat), tree);
|
RealImpl::InitData(gmat, row_gpairs, *(*dmat), tree);
|
||||||
hist_.AddHistRow(0);
|
hist_.AddHistRow(0);
|
||||||
|
|
||||||
BuildHist(row_gpairs, row_set_collection_[0],
|
std::vector<ExpandEntry> nodes = {ExpandEntry(0, -1, -1, tree.GetDepth(0), 0.0, 0)};
|
||||||
gmat, quantile_index_block, hist_[0], false);
|
std::vector<std::vector<float*>> hist_buffers;
|
||||||
|
std::vector<std::vector<uint8_t>> hist_is_init;
|
||||||
RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), tree);
|
BuildHistsBatch(nodes, const_cast<RegTree*>(&tree), gmat, row_gpairs, &hist_buffers, &hist_is_init);
|
||||||
|
RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), const_cast<RegTree*>(&tree), &snode_[0], tree[0].Parent());
|
||||||
|
EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers);
|
||||||
|
|
||||||
/* Compute correct split (best_split) using the computed histogram */
|
/* Compute correct split (best_split) using the computed histogram */
|
||||||
const size_t num_row = dmat->get()->Info().num_row_;
|
const size_t num_row = dmat->get()->Info().num_row_;
|
||||||
@ -197,6 +204,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
const auto split_gain
|
const auto split_gain
|
||||||
= evaluator->ComputeSplitScore(0, fid, GradStats(left_sum),
|
= evaluator->ComputeSplitScore(0, fid, GradStats(left_sum),
|
||||||
GradStats(right_sum));
|
GradStats(right_sum));
|
||||||
|
|
||||||
if (split_gain > best_split_gain) {
|
if (split_gain > best_split_gain) {
|
||||||
best_split_gain = split_gain;
|
best_split_gain = split_gain;
|
||||||
best_split_feature = fid;
|
best_split_feature = fid;
|
||||||
@ -206,7 +214,8 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* Now compare against result given by EvaluateSplit() */
|
/* Now compare against result given by EvaluateSplit() */
|
||||||
RealImpl::EvaluateSplit(0, gmat, hist_, *(*dmat), tree);
|
EvaluateSplitsBatch(nodes, gmat, **dmat, hist_is_init, hist_buffers);
|
||||||
|
|
||||||
ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
|
ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
|
||||||
ASSERT_EQ(snode_[0].best.split_value, gmat.cut.cut[best_split_threshold]);
|
ASSERT_EQ(snode_[0].best.split_value, gmat.cut.cut[best_split_threshold]);
|
||||||
|
|
||||||
@ -289,7 +298,7 @@ TEST(Updater, QuantileHist_EvalSplits) {
|
|||||||
std::vector<std::pair<std::string, std::string>> cfg
|
std::vector<std::pair<std::string, std::string>> cfg
|
||||||
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
|
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
|
||||||
{"split_evaluator", "elastic_net"},
|
{"split_evaluator", "elastic_net"},
|
||||||
{"reg_lambda", "0"}, {"reg_alpha", "0"}, {"max_delta_step", "0"},
|
{"reg_lambda", "1.0f"}, {"reg_alpha", "0"}, {"max_delta_step", "0"},
|
||||||
{"min_child_weight", "0"}};
|
{"min_child_weight", "0"}};
|
||||||
QuantileHistMock maker(cfg);
|
QuantileHistMock maker(cfg);
|
||||||
maker.TestEvaluateSplit();
|
maker.TestEvaluateSplit();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user