Optimized ApplySplit, BuildHist and UpdatePredictCache functions on CPU (#5244)

* Split up sparse and dense build hist kernels.
* Add `PartitionBuilder`.

parent b81f8cbbc0 · commit 1b97eaf7a7
@@ -37,6 +37,7 @@ class Column {
  size_t Size() const { return len_; }
  uint32_t GetGlobalBinIdx(size_t idx) const { return index_base_ + index_[idx]; }
  uint32_t GetFeatureBinIdx(size_t idx) const { return index_[idx]; }
+  common::Span<const uint32_t> GetFeatureBinIdxPtr() const { return { index_, len_ }; }
  // column.GetFeatureBinIdx(idx) + column.GetBaseIdx(idx) ==
  // column.GetGlobalBinIdx(idx)
  uint32_t GetBaseIdx() const { return index_base_; }
@@ -186,8 +187,8 @@ class ColumnMatrix {

  std::vector<size_t> feature_counts_;
  std::vector<ColumnType> type_;
-  SimpleArray<uint32_t> index_;  // index_: may store smaller integers; needs padding
-  SimpleArray<size_t> row_ind_;
+  std::vector<uint32_t> index_;  // index_: may store smaller integers; needs padding
+  std::vector<size_t> row_ind_;
  std::vector<ColumnBoundary> boundary_;

  // index_base_[fid]: least bin id for feature fid

@@ -672,7 +672,7 @@ void GHistIndexBlockMatrix::Init(const GHistIndexMatrix& gmat,
}

/*!
- * \brief fill a histogram by zeroes
+ * \brief fill a histogram by zeros in range [begin, end)
 */
void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
  memset(hist.data() + begin, '\0', (end-begin)*sizeof(tree::GradStats));
@@ -719,40 +719,141 @@ void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2,
  }
}

+struct Prefetch {
+ public:
+  static constexpr size_t kCacheLineSize = 64;
+  static constexpr size_t kPrefetchOffset = 10;
+  static constexpr size_t kPrefetchStep =
+      kCacheLineSize / sizeof(decltype(GHistIndexMatrix::index)::value_type);
+
+ private:
+  static constexpr size_t kNoPrefetchSize =
+      kPrefetchOffset + kCacheLineSize /
+      sizeof(decltype(GHistIndexMatrix::row_ptr)::value_type);
+
+ public:
+  static size_t NoPrefetchSize(size_t rows) {
+    return std::min(rows, kNoPrefetchSize);
+  }
+};
+
+constexpr size_t Prefetch::kNoPrefetchSize;
+
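Aside (not part of this diff): the `PREFETCH_READ_T0` macro used by the kernels below wraps a platform prefetch intrinsic. A minimal sketch of such a macro, assuming SSE intrinsics or GCC builtins are available — xgboost's actual definition lives in hist_util.cc and may differ in detail:

```cpp
// Hedged sketch of a read-prefetch macro; the guards and names are illustrative.
#if defined(__SSE__) || defined(_MSC_VER)
  #include <xmmintrin.h>
  #define PREFETCH_READ_T0(addr) \
    _mm_prefetch(reinterpret_cast<const char*>(addr), _MM_HINT_T0)
#elif defined(__GNUC__)
  // rw = 0 (read), locality = 3 (keep in all cache levels, i.e. T0)
  #define PREFETCH_READ_T0(addr) __builtin_prefetch(addr, 0, 3)
#else
  #define PREFETCH_READ_T0(addr) do {} while (0)  // graceful no-op fallback
#endif
```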
+template<typename FPType, bool do_prefetch>
+void BuildHistDenseKernel(const std::vector<GradientPair>& gpair,
+                          const RowSetCollection::Elem row_indices,
+                          const GHistIndexMatrix& gmat,
+                          const size_t n_features,
+                          GHistRow hist) {
+  const size_t size = row_indices.Size();
+  const size_t* rid = row_indices.begin;
+  const float* pgh = reinterpret_cast<const float*>(gpair.data());
+  const uint32_t* gradient_index = gmat.index.data();
+  FPType* hist_data = reinterpret_cast<FPType*>(hist.data());
+
+  const uint32_t two {2};  // Each element from 'gpair' and 'hist' contains
+                           // 2 FP values: gradient and hessian.
+                           // So we need to multiply each row-index/bin-index by 2
+                           // to work with gradient pairs as a single flat FP array
+
+  for (size_t i = 0; i < size; ++i) {
+    const size_t icol_start = rid[i] * n_features;
+    const size_t idx_gh = two * rid[i];
+
+    if (do_prefetch) {
+      const size_t icol_start_prefetch = rid[i + Prefetch::kPrefetchOffset] * n_features;
+
+      PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]);
+      for (size_t j = icol_start_prefetch; j < icol_start_prefetch + n_features;
+           j += Prefetch::kPrefetchStep) {
+        PREFETCH_READ_T0(gradient_index + j);
+      }
+    }
+
+    for (size_t j = icol_start; j < icol_start + n_features; ++j) {
+      const uint32_t idx_bin = two * gradient_index[j];
+
+      hist_data[idx_bin] += pgh[idx_gh];
+      hist_data[idx_bin+1] += pgh[idx_gh+1];
+    }
+  }
+}
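The `reinterpret_cast` above relies on each `GradientPair` storing its gradient and hessian as two contiguous floats, so row `i` maps to flat positions `2*i` and `2*i + 1`. A standalone illustration with hypothetical values (assumes the standard `xgboost::GradientPair` layout):

```cpp
#include <cassert>
#include <vector>
#include <xgboost/base.h>  // xgboost::GradientPair

int main() {
  std::vector<xgboost::GradientPair> gpair = {{0.5f, 1.0f}, {-0.25f, 2.0f}};
  const float* flat = reinterpret_cast<const float*>(gpair.data());
  assert(flat[0] == 0.5f && flat[1] == 1.0f);    // row 0: gradient, hessian
  assert(flat[2] == -0.25f && flat[3] == 2.0f);  // row 1: gradient, hessian
  return 0;
}
```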
+template<typename FPType, bool do_prefetch>
+void BuildHistSparseKernel(const std::vector<GradientPair>& gpair,
+                           const RowSetCollection::Elem row_indices,
+                           const GHistIndexMatrix& gmat,
+                           GHistRow hist) {
+  const size_t size = row_indices.Size();
+  const size_t* rid = row_indices.begin;
+  const float* pgh = reinterpret_cast<const float*>(gpair.data());
+  const uint32_t* gradient_index = gmat.index.data();
+  const size_t* row_ptr = gmat.row_ptr.data();
+  FPType* hist_data = reinterpret_cast<FPType*>(hist.data());
+
+  const uint32_t two {2};  // Each element from 'gpair' and 'hist' contains
+                           // 2 FP values: gradient and hessian.
+                           // So we need to multiply each row-index/bin-index by 2
+                           // to work with gradient pairs as a single flat FP array
+
+  for (size_t i = 0; i < size; ++i) {
+    const size_t icol_start = row_ptr[rid[i]];
+    const size_t icol_end = row_ptr[rid[i]+1];
+    const size_t idx_gh = two * rid[i];
+
+    if (do_prefetch) {
+      const size_t icol_start_prefetch = row_ptr[rid[i+Prefetch::kPrefetchOffset]];
+      const size_t icol_end_prefetch = row_ptr[rid[i+Prefetch::kPrefetchOffset]+1];
+
+      PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]);
+      for (size_t j = icol_start_prefetch; j < icol_end_prefetch; j += Prefetch::kPrefetchStep) {
+        PREFETCH_READ_T0(gradient_index + j);
+      }
+    }
+
+    for (size_t j = icol_start; j < icol_end; ++j) {
+      const uint32_t idx_bin = two * gradient_index[j];
+      hist_data[idx_bin] += pgh[idx_gh];
+      hist_data[idx_bin+1] += pgh[idx_gh+1];
+    }
+  }
+}
+
+template<typename FPType, bool do_prefetch>
+void BuildHistKernel(const std::vector<GradientPair>& gpair,
+                     const RowSetCollection::Elem row_indices,
+                     const GHistIndexMatrix& gmat, const bool isDense, GHistRow hist) {
+  if (row_indices.Size() && isDense) {
+    const size_t* row_ptr = gmat.row_ptr.data();
+    const size_t n_features = row_ptr[row_indices.begin[0]+1] - row_ptr[row_indices.begin[0]];
+    BuildHistDenseKernel<FPType, do_prefetch>(gpair, row_indices, gmat, n_features, hist);
+  } else {
+    BuildHistSparseKernel<FPType, do_prefetch>(gpair, row_indices, gmat, hist);
+  }
+}
void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
                             const RowSetCollection::Elem row_indices,
                             const GHistIndexMatrix& gmat,
-                             GHistRow hist) {
-  const size_t* rid = row_indices.begin;
+                             GHistRow hist,
+                             bool isDense) {
+  using FPType = decltype(tree::GradStats::sum_grad);
  const size_t nrows = row_indices.Size();
-  const uint32_t* index = gmat.index.data();
-  const size_t* row_ptr = gmat.row_ptr.data();
-  const float* pgh = reinterpret_cast<const float*>(gpair.data());
+  const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);

-  double* hist_data = reinterpret_cast<double*>(hist.data());
+  // if we need to work with all rows from the bin-matrix (e.g. root node)
+  const bool contiguousBlock = (row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1);

-  const size_t cache_line_size = 64;
-  const size_t prefetch_offset = 10;
-  size_t no_prefetch_size = prefetch_offset + cache_line_size/sizeof(*rid);
-  no_prefetch_size = no_prefetch_size > nrows ? nrows : no_prefetch_size;
+  if (contiguousBlock) {
+    // contiguous memory access, built-in HW prefetching is enough
+    BuildHistKernel<FPType, false>(gpair, row_indices, gmat, isDense, hist);
+  } else {
+    const RowSetCollection::Elem span1(row_indices.begin, row_indices.end - no_prefetch_size);
+    const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, row_indices.end);

-  for (size_t i = 0; i < nrows; ++i) {
-    const size_t icol_start = row_ptr[rid[i]];
-    const size_t icol_end = row_ptr[rid[i]+1];
-
-    if (i < nrows - no_prefetch_size) {
-      PREFETCH_READ_T0(row_ptr + rid[i + prefetch_offset]);
-      PREFETCH_READ_T0(pgh + 2*rid[i + prefetch_offset]);
-    }
-
-    for (size_t j = icol_start; j < icol_end; ++j) {
-      const uint32_t idx_bin = 2*index[j];
-      const size_t idx_gh = 2*rid[i];
-
-      hist_data[idx_bin] += pgh[idx_gh];
-      hist_data[idx_bin+1] += pgh[idx_gh+1];
-    }
+    BuildHistKernel<FPType, true>(gpair, span1, gmat, isDense, hist);
+    // no prefetching to avoid loading extra memory
+    BuildHistKernel<FPType, false>(gpair, span2, gmat, isDense, hist);
  }
}

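Why the `contiguousBlock` test works (a reading aid, not from the diff): row indices in a `RowSetCollection::Elem` are stored sorted, so only a gap-free range can satisfy `last - first == count - 1`. A standalone illustration with hypothetical data:

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

int main() {
  std::vector<std::size_t> root_rows(1000);
  std::iota(root_rows.begin(), root_rows.end(), 0);  // root node: every row, no gaps
  assert(root_rows.back() - root_rows.front() == root_rows.size() - 1);

  std::vector<std::size_t> child_rows = {0, 2, 3, 7};  // child node: gaps present
  assert(child_rows.back() - child_rows.front() != child_rows.size() - 1);
  return 0;
}
```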
@@ -1,5 +1,5 @@
/*!
- * Copyright 2017 by Contributors
+ * Copyright 2017-2020 by Contributors
 * \file hist_util.h
 * \brief Utility for fast histogram aggregation
 * \author Philip Cho, Tianqi Chen
@@ -25,75 +25,6 @@

namespace xgboost {
namespace common {

-/*
- * \brief A thin wrapper around dynamically allocated C-style array.
- * Make sure to call resize() before use.
- */
-template<typename T>
-struct SimpleArray {
-  ~SimpleArray() {
-    std::free(ptr_);
-    ptr_ = nullptr;
-  }
-
-  void resize(size_t n) {
-    T* ptr = static_cast<T*>(std::malloc(n * sizeof(T)));
-    CHECK(ptr) << "Failed to allocate memory";
-    if (ptr_) {
-      std::memcpy(ptr, ptr_, n_ * sizeof(T));
-      std::free(ptr_);
-    }
-    ptr_ = ptr;
-    n_ = n;
-  }
-
-  T& operator[](size_t idx) {
-    return ptr_[idx];
-  }
-
-  T& operator[](size_t idx) const {
-    return ptr_[idx];
-  }
-
-  size_t size() const {
-    return n_;
-  }
-
-  T back() const {
-    return ptr_[n_-1];
-  }
-
-  T* data() {
-    return ptr_;
-  }
-
-  const T* data() const {
-    return ptr_;
-  }
-
-  T* begin() {
-    return ptr_;
-  }
-
-  const T* begin() const {
-    return ptr_;
-  }
-
-  T* end() {
-    return ptr_ + n_;
-  }
-
-  const T* end() const {
-    return ptr_ + n_;
-  }
-
- private:
-  T* ptr_ = nullptr;
-  size_t n_ = 0;
-};
-
/*!
 * \brief A single row in global histogram index.
 * Directly represent the global index in the histogram entry.
@@ -161,7 +92,7 @@ class HistogramCuts {
    return idx;
  }

-  BinIdx SearchBin(Entry const& e) {
+  BinIdx SearchBin(Entry const& e) const {
    return SearchBin(e.fvalue, e.index);
  }
};
@@ -261,8 +192,9 @@ size_t DeviceSketch(int device,

/*!
 * \brief preprocessed global index matrix, in CSR format
- * Transform floating values to integer index in histogram
- * This is a global histogram index.
+ *
+ * Transform floating values to integer index in histogram. This is a global
+ * index for the CPU histogram. On GPU the ellpack page is used.
 */
struct GHistIndexMatrix {
  /*! \brief row pointer to rows by element position */
@@ -606,17 +538,15 @@ class ParallelGHistBuilder {
 */
class GHistBuilder {
 public:
-  // initialize builder
-  inline void Init(size_t nthread, uint32_t nbins) {
-    nthread_ = nthread;
-    nbins_ = nbins;
-  }
+  GHistBuilder() : nthread_{0}, nbins_{0} {}
+  GHistBuilder(size_t nthread, uint32_t nbins) : nthread_{nthread}, nbins_{nbins} {}

  // construct a histogram via histogram aggregation
  void BuildHist(const std::vector<GradientPair>& gpair,
                 const RowSetCollection::Elem row_indices,
                 const GHistIndexMatrix& gmat,
-                GHistRow hist);
+                GHistRow hist,
+                bool isDense);
  // same, with feature grouping
  void BuildBlockHist(const std::vector<GradientPair>& gpair,
                      const RowSetCollection::Elem row_indices,
@@ -625,7 +555,7 @@ class GHistBuilder {
  // construct a histogram via subtraction trick
  void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);

-  uint32_t GetNumBins() {
+  uint32_t GetNumBins() const {
    return nbins_;
  }

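For readers unfamiliar with the subtraction trick referenced above: a node's histogram equals its parent's histogram minus its sibling's, so only the smaller child needs an explicit pass over the data. A hedged sketch, assuming the flat layout of two doubles (gradient, hessian) per bin — illustrative only, xgboost's `SubtractionTrick` operates on `GHistRow` instead:

```cpp
#include <cstddef>

// With nbins bins and two doubles (grad, hess) per bin, a histogram is a
// flat array of 2*nbins doubles.
void SubtractionTrickSketch(double* self, const double* sibling,
                            const double* parent, std::size_t nbins) {
  for (std::size_t i = 0; i < 2 * nbins; ++i) {
    self[i] = parent[i] - sibling[i];  // hist(node) = hist(parent) - hist(sibling)
  }
}
```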
@@ -10,6 +10,7 @@
#include <xgboost/data.h>
#include <algorithm>
#include <vector>
+#include <utility>

namespace xgboost {
namespace common {
@@ -29,7 +30,7 @@ class RowSetCollection {
      = default;
  Elem(const size_t* begin,
       const size_t* end,
-       int node_id)
+       int node_id = -1)
      : begin(begin), end(end), node_id(node_id) {}

  inline size_t Size() const {
@@ -57,6 +58,13 @@ class RowSetCollection {
        << "access element that is not in the set";
    return e;
  }
+
+  /*! \brief return corresponding element set given the node_id */
+  inline Elem& operator[](unsigned node_id) {
+    Elem& e = elem_of_each_node_[node_id];
+    return e;
+  }

  // clear up things
  inline void Clear() {
    elem_of_each_node_.clear();
@@ -83,25 +91,18 @@ class RowSetCollection {
  }
  // split rowset into two
  inline void AddSplit(unsigned node_id,
-                       const std::vector<Split>& row_split_tloc,
                       unsigned left_node_id,
-                       unsigned right_node_id) {
+                       unsigned right_node_id,
+                       size_t n_left,
+                       size_t n_right) {
    const Elem e = elem_of_each_node_[node_id];
-    const auto nthread = static_cast<bst_omp_uint>(row_split_tloc.size());
    CHECK(e.begin != nullptr);
    size_t* all_begin = dmlc::BeginPtr(row_indices_);
    size_t* begin = all_begin + (e.begin - all_begin);

-    size_t* it = begin;
-    for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
-      std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
-      it += row_split_tloc[tid].left.size();
-    }
-    size_t* split_pt = it;
-    for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
-      std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
-      it += row_split_tloc[tid].right.size();
-    }
+    CHECK_EQ(n_left + n_right, e.Size());
+    CHECK_LE(begin + n_left, e.end);
+    CHECK_EQ(begin + n_left + n_right, e.end);

    if (left_node_id >= elem_of_each_node_.size()) {
      elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
@@ -110,12 +111,12 @@ class RowSetCollection {
      elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
    }

-    elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id);
-    elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id);
+    elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id);
+    elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id);
    elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
  }

-  // stores the row indices in the set
+  // stores the row indexes in the set
  std::vector<size_t> row_indices_;

 private:
@@ -123,6 +124,121 @@ class RowSetCollection {
  std::vector<Elem> elem_of_each_node_;
};

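A note on the new `AddSplit` contract (reading aid, not from the diff): the caller must already have arranged the node's slice of `row_indices_` as left rows followed by right rows; `AddSplit` only records the boundary. A hypothetical illustration:

```cpp
#include <cassert>
#include <vector>

int main() {
  // Node's slice after the merge step: [left rows..., right rows...]
  std::vector<size_t> slice = {2, /* boundary */ 5, 7, 9};
  size_t n_left = 1, n_right = 3;
  assert(n_left + n_right == slice.size());  // mirrors the CHECK_EQ above
  // left child  covers [begin, begin + n_left) -> {2}
  // right child covers [begin + n_left, end)   -> {5, 7, 9}
  return 0;
}
```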
+// The builder is required to partition samples into the left and right children for a set of nodes.
+// Responsible for:
+// 1) Efficient memory allocation for intermediate results of multi-threaded work
+// 2) Merging partial results produced by threads into the original row set (row_set_collection_)
+// BlockSize is a template parameter to enable memory alignment easily with the C++11 'alignas()' feature
+template<size_t BlockSize>
+class PartitionBuilder {
+ public:
+  template<typename Func>
+  void Init(const size_t n_tasks, size_t n_nodes, Func funcNTasks) {
+    left_right_nodes_sizes_.resize(n_nodes);
+    blocks_offsets_.resize(n_nodes+1);
+
+    blocks_offsets_[0] = 0;
+    for (size_t i = 1; i < n_nodes+1; ++i) {
+      blocks_offsets_[i] = blocks_offsets_[i-1] + funcNTasks(i-1);
+    }
+
+    if (n_tasks > max_n_tasks_) {
+      mem_blocks_.resize(n_tasks);
+      max_n_tasks_ = n_tasks;
+    }
+  }
+
+  common::Span<size_t> GetLeftBuffer(int nid, size_t begin, size_t end) {
+    const size_t task_idx = GetTaskIdx(nid, begin);
+    return { mem_blocks_.at(task_idx).left(), end - begin };
+  }
+
+  common::Span<size_t> GetRightBuffer(int nid, size_t begin, size_t end) {
+    const size_t task_idx = GetTaskIdx(nid, begin);
+    return { mem_blocks_.at(task_idx).right(), end - begin };
+  }
+
+  void SetNLeftElems(int nid, size_t begin, size_t end, size_t n_left) {
+    size_t task_idx = GetTaskIdx(nid, begin);
+    mem_blocks_.at(task_idx).n_left = n_left;
+  }
+
+  void SetNRightElems(int nid, size_t begin, size_t end, size_t n_right) {
+    size_t task_idx = GetTaskIdx(nid, begin);
+    mem_blocks_.at(task_idx).n_right = n_right;
+  }
+
+  size_t GetNLeftElems(int nid) const {
+    return left_right_nodes_sizes_[nid].first;
+  }
+
+  size_t GetNRightElems(int nid) const {
+    return left_right_nodes_sizes_[nid].second;
+  }
+
+  // Each thread has partial results for some set of tree-nodes
+  // The function decides the order of merging partial results into the final row set
+  void CalculateRowOffsets() {
+    for (size_t i = 0; i < blocks_offsets_.size()-1; ++i) {
+      size_t n_left = 0;
+      for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) {
+        mem_blocks_[j].n_offset_left = n_left;
+        n_left += mem_blocks_[j].n_left;
+      }
+      size_t n_right = 0;
+      for (size_t j = blocks_offsets_[i]; j < blocks_offsets_[i+1]; ++j) {
+        mem_blocks_[j].n_offset_right = n_left + n_right;
+        n_right += mem_blocks_[j].n_right;
+      }
+      left_right_nodes_sizes_[i] = {n_left, n_right};
+    }
+  }
+
+  void MergeToArray(int nid, size_t begin, size_t* rows_indexes) {
+    size_t task_idx = GetTaskIdx(nid, begin);
+
+    size_t* left_result = rows_indexes + mem_blocks_[task_idx].n_offset_left;
+    size_t* right_result = rows_indexes + mem_blocks_[task_idx].n_offset_right;
+
+    const size_t* left = mem_blocks_[task_idx].left();
+    const size_t* right = mem_blocks_[task_idx].right();
+
+    std::copy_n(left, mem_blocks_[task_idx].n_left, left_result);
+    std::copy_n(right, mem_blocks_[task_idx].n_right, right_result);
+  }
+
+ protected:
+  size_t GetTaskIdx(int nid, size_t begin) {
+    return blocks_offsets_[nid] + begin / BlockSize;
+  }
+
+  struct BlockInfo {
+    size_t n_left;
+    size_t n_right;
+
+    size_t n_offset_left;
+    size_t n_offset_right;
+
+    size_t* left() {
+      return &left_data_[0];
+    }
+
+    size_t* right() {
+      return &right_data_[0];
+    }
+   private:
+    alignas(128) size_t left_data_[BlockSize];
+    alignas(128) size_t right_data_[BlockSize];
+  };
+  std::vector<std::pair<size_t, size_t>> left_right_nodes_sizes_;
+  std::vector<size_t> blocks_offsets_;
+  std::vector<BlockInfo> mem_blocks_;
+  size_t max_n_tasks_ = 0;
+};
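Minimal usage sketch of `PartitionBuilder` (assumed flow; see the new `tests/cpp/common/test_partition_builder.cc` at the end of this diff for a complete test): a single node whose rows fit in one block. Sizes and counts are illustrative:

```cpp
// Hedged sketch; requires src/common/row_set.h from this patch.
void PartitionBuilderSketch() {
  xgboost::common::PartitionBuilder<64> builder;
  builder.Init(/*n_tasks=*/1, /*n_nodes=*/1, [](size_t) { return 1; });

  auto left  = builder.GetLeftBuffer(/*nid=*/0, /*begin=*/0, /*end=*/64);
  auto right = builder.GetRightBuffer(0, 0, 64);
  // ... a worker writes row ids going left into `left`, the rest into `right` ...
  builder.SetNLeftElems(0, 0, 64, /*n_left=*/40);
  builder.SetNRightElems(0, 0, 64, /*n_right=*/24);

  builder.CalculateRowOffsets();               // decide where each block's results land
  size_t rows[64];
  builder.MergeToArray(0, /*begin=*/0, rows);  // rows = [40 left ids, then 24 right ids]
}
```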
} // namespace common
} // namespace xgboost

@@ -9,6 +9,8 @@
#include <vector>
#include <algorithm>

+#include "xgboost/logging.h"

namespace xgboost {
namespace common {

@@ -20,11 +22,11 @@ class Range1d {
    CHECK_LT(begin, end);
  }

-  size_t begin() {
+  size_t begin() const {  // NOLINT
    return begin_;
  }

-  size_t end() {
+  size_t end() const {  // NOLINT
    return end_;
  }

@@ -239,17 +239,14 @@ void QuantileHistMaker::Builder::BuildNodeStats(
  builder_monitor_.Stop("BuildNodeStats");
}

-void QuantileHistMaker::Builder::EvaluateSplits(
-    const GHistIndexMatrix &gmat,
-    const ColumnMatrix &column_matrix,
-    DMatrix *p_fmat,
-    RegTree *p_tree,
-    int *num_leaves,
-    int depth,
-    unsigned *timestamp,
-    std::vector<ExpandEntry> *temp_qexpand_depth) {
-  EvaluateSplit(qexpand_depth_wise_, gmat, hist_, *p_fmat, *p_tree);
-
+void QuantileHistMaker::Builder::AddSplitsToTree(
+    const GHistIndexMatrix &gmat,
+    RegTree *p_tree,
+    int *num_leaves,
+    int depth,
+    unsigned *timestamp,
+    std::vector<ExpandEntry>* nodes_for_apply_split,
+    std::vector<ExpandEntry>* temp_qexpand_depth) {
  for (auto const& entry : qexpand_depth_wise_) {
    int nid = entry.nid;
@@ -258,7 +255,17 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
        (param_.max_leaves > 0 && (*num_leaves) == param_.max_leaves)) {
      (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
    } else {
-      this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree);
+      nodes_for_apply_split->push_back(entry);
+
+      NodeEntry& e = snode_[nid];
+      bst_float left_leaf_weight =
+          spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate;
+      bst_float right_leaf_weight =
+          spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
+      p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
+                         e.best.DefaultLeft(), e.weight, left_leaf_weight,
+                         right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);

      int left_id = (*p_tree)[nid].LeftChild();
      int right_id = (*p_tree)[nid].RightChild();
      temp_qexpand_depth->push_back(ExpandEntry(left_id, right_id,
@@ -271,6 +278,24 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
    }
  }
}

+void QuantileHistMaker::Builder::EvaluateAndApplySplits(
+    const GHistIndexMatrix &gmat,
+    const ColumnMatrix &column_matrix,
+    RegTree *p_tree,
+    int *num_leaves,
+    int depth,
+    unsigned *timestamp,
+    std::vector<ExpandEntry> *temp_qexpand_depth) {
+  EvaluateSplits(qexpand_depth_wise_, gmat, hist_, *p_tree);
+
+  std::vector<ExpandEntry> nodes_for_apply_split;
+  AddSplitsToTree(gmat, p_tree, num_leaves, depth, timestamp,
+                  &nodes_for_apply_split, temp_qexpand_depth);
+
+  ApplySplit(nodes_for_apply_split, gmat, column_matrix, hist_, p_tree);
+}

// Split nodes into 2 sets depending on the amount of rows in each node
// Histograms for small nodes will be built explicitly
// Histograms for big nodes will be built by the 'Subtraction Trick'
@@ -335,7 +360,7 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise(
    SyncHistograms(starting_index, sync_count, p_tree);

    BuildNodeStats(gmat, p_fmat, p_tree, gpair_h);
-    EvaluateSplits(gmat, column_matrix, p_fmat, p_tree, &num_leaves, depth, &timestamp,
+    EvaluateAndApplySplits(gmat, column_matrix, p_tree, &num_leaves, depth, &timestamp,
                   &temp_qexpand_depth);
    // clean up
    qexpand_depth_wise_.clear();
@@ -367,7 +392,7 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(

  this->InitNewNode(ExpandEntry::kRootNid, gmat, gpair_h, *p_fmat, *p_tree);

-  this->EvaluateSplit({node}, gmat, hist_, *p_fmat, *p_tree);
+  this->EvaluateSplits({node}, gmat, hist_, *p_tree);
  node.loss_chg = snode_[ExpandEntry::kRootNid].best.loss_chg;

  qexpand_loss_guided_->push(node);
@@ -377,12 +402,19 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
    const ExpandEntry candidate = qexpand_loss_guided_->top();
    const int nid = candidate.nid;
    qexpand_loss_guided_->pop();
-    if (candidate.loss_chg <= kRtEps
-        || (param_.max_depth > 0 && candidate.depth == param_.max_depth)
-        || (param_.max_leaves > 0 && num_leaves == param_.max_leaves) ) {
+    if (candidate.IsValid(param_, num_leaves)) {
      (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
    } else {
-      this->ApplySplit(nid, gmat, column_matrix, hist_, *p_fmat, p_tree);
+      NodeEntry& e = snode_[nid];
+      bst_float left_leaf_weight =
+          spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate;
+      bst_float right_leaf_weight =
+          spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
+      p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
+                         e.best.DefaultLeft(), e.weight, left_leaf_weight,
+                         right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
+
+      this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);

      const int cleft = (*p_tree)[nid].LeftChild();
      const int cright = (*p_tree)[nid].RightChild();
@@ -410,7 +442,7 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
                       snode_[cleft].weight, snode_[cright].weight);
    interaction_constraints_.Split(nid, featureid, cleft, cright);

-    this->EvaluateSplit({left_node, right_node}, gmat, hist_, *p_fmat, *p_tree);
+    this->EvaluateSplits({left_node, right_node}, gmat, hist_, *p_tree);
    left_node.loss_chg = snode_[cleft].best.loss_chg;
    right_node.loss_chg = snode_[cright].best.loss_chg;
@@ -473,7 +505,14 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(

  CHECK_GT(out_preds.size(), 0U);

-  for (const RowSetCollection::Elem rowset : row_set_collection_) {
+  size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin();
+
+  common::BlockedSpace2d space(n_nodes, [&](size_t node) {
+    return row_set_collection_[node].Size();
+  }, 1024);
+
+  common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
+    const RowSetCollection::Elem rowset = row_set_collection_[node];
    if (rowset.begin != nullptr && rowset.end != nullptr) {
      int nid = rowset.node_id;
      bst_float leaf_value;
@@ -487,11 +526,11 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache(
      }
      leaf_value = (*p_last_tree_)[nid].LeafValue();

-      for (const size_t* it = rowset.begin; it < rowset.end; ++it) {
+      for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
        out_preds[*it] += leaf_value;
      }
    }
-  }
+  });

  builder_monitor_.Stop("UpdatePredictionCache");
  return true;
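`BlockedSpace2d`/`ParallelFor2d` are xgboost threading helpers that turn every (node, row-block) pair into an independent task. A conceptual stand-in to show the iteration shape — not the actual implementation:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

// Sketch: dimension 1 enumerates nodes, dimension 2 chops each node's rows
// into blocks of `grain` rows; fn(node, begin, end) runs once per block.
void ParallelFor2dSketch(std::size_t n_nodes,
                         const std::function<std::size_t(std::size_t)>& rows_in_node,
                         std::size_t grain,
                         const std::function<void(std::size_t, std::size_t, std::size_t)>& fn) {
  struct Task { std::size_t node, begin, end; };
  std::vector<Task> tasks;
  for (std::size_t node = 0; node < n_nodes; ++node) {
    const std::size_t n = rows_in_node(node);
    for (std::size_t b = 0; b < n; b += grain) {
      tasks.push_back({node, b, std::min(b + grain, n)});
    }
  }
  #pragma omp parallel for schedule(guided)
  for (std::int64_t i = 0; i < static_cast<std::int64_t>(tasks.size()); ++i) {
    fn(tasks[i].node, tasks[i].begin, tasks[i].end);
  }
}
```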
@@ -526,7 +565,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
  {
    this->nthread_ = omp_get_num_threads();
  }
-  hist_builder_.Init(this->nthread_, nbins);
+  hist_builder_ = GHistBuilder(this->nthread_, nbins);

  std::vector<size_t>& row_indices = row_set_collection_.row_indices_;
  row_indices.resize(info.num_row_);
@@ -674,12 +713,11 @@ bool QuantileHistMaker::Builder::SplitContainsMissingValues(const GradStats e,
}

// nodes_set - set of nodes to be processed in parallel
-void QuantileHistMaker::Builder::EvaluateSplit(const std::vector<ExpandEntry>& nodes_set,
+void QuantileHistMaker::Builder::EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
                                               const GHistIndexMatrix& gmat,
                                               const HistCollection& hist,
-                                              const DMatrix& fmat,
                                               const RegTree& tree) {
-  builder_monitor_.Start("EvaluateSplit");
+  builder_monitor_.Start("EvaluateSplits");

  const size_t n_nodes_in_set = nodes_set.size();
  const size_t nthread = std::max(1, this->nthread_);
@@ -714,11 +752,11 @@ void QuantileHistMaker::Builder::EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
  for (auto idx_in_feature_set = r.begin(); idx_in_feature_set < r.end(); ++idx_in_feature_set) {
    const auto fid = features_sets[nid_in_set]->ConstHostVector()[idx_in_feature_set];
    if (interaction_constraints_.Query(nid, fid)) {
-      auto grad_stats = this->EnumerateSplit<+1>(gmat, node_hist, snode_[nid], fmat.Info(),
-          &best_split_tloc_[nthread*nid_in_set + tid], fid, nid);
+      auto grad_stats = this->EnumerateSplit<+1>(gmat, node_hist, snode_[nid],
+          &best_split_tloc_[nthread*nid_in_set + tid], fid, nid);
      if (SplitContainsMissingValues(grad_stats, snode_[nid])) {
-        this->EnumerateSplit<-1>(gmat, node_hist, snode_[nid], fmat.Info(),
-            &best_split_tloc_[nthread*nid_in_set + tid], fid, nid);
+        this->EnumerateSplit<-1>(gmat, node_hist, snode_[nid],
+            &best_split_tloc_[nthread*nid_in_set + tid], fid, nid);
      }
    }
  }
@@ -732,198 +770,240 @@ void QuantileHistMaker::Builder::EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
    }
  }

-  builder_monitor_.Stop("EvaluateSplit");
+  builder_monitor_.Stop("EvaluateSplits");
}

-void QuantileHistMaker::Builder::ApplySplit(int nid,
+// Split row indexes (rid_span) into 2 parts (left_part, right_part) depending
+// on the comparison of index values (idx_span) with the split point (split_cond).
+// Handles dense columns.
+// Analog of std::stable_partition, but not in-place.
+template <bool default_left>
+inline std::pair<size_t, size_t> PartitionDenseKernel(
+    common::Span<const size_t> rid_span, common::Span<const uint32_t> idx_span,
+    const int32_t split_cond, const uint32_t offset,
+    common::Span<size_t> left_part, common::Span<size_t> right_part) {
+  const uint32_t* idx = idx_span.data();
+  size_t* p_left_part = left_part.data();
+  size_t* p_right_part = right_part.data();
+  size_t nleft_elems = 0;
+  size_t nright_elems = 0;
+
+  const uint32_t missing_val = std::numeric_limits<uint32_t>::max();
+
+  for (auto rid : rid_span) {
+    if (idx[rid] == missing_val) {
+      if (default_left) {
+        p_left_part[nleft_elems++] = rid;
+      } else {
+        p_right_part[nright_elems++] = rid;
+      }
+    } else {
+      if (static_cast<int32_t>(idx[rid] + offset) <= split_cond) {
+        p_left_part[nleft_elems++] = rid;
+      } else {
+        p_right_part[nright_elems++] = rid;
+      }
+    }
+  }
+
+  return {nleft_elems, nright_elems};
+}
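What "stable_partition, but not in-place" means (illustration only): relative order is preserved within each side, as with `std::stable_partition`, but the two sides land in separate output buffers instead of being rearranged within the input:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<std::size_t> rows = {3, 1, 4, 1, 5, 9, 2, 6};
  std::vector<std::size_t> left, right;
  auto goes_left = [](std::size_t r) { return r % 2 == 0; };

  for (std::size_t r : rows) {  // one pass; stable by construction
    (goes_left(r) ? left : right).push_back(r);
  }
  assert((left == std::vector<std::size_t>{4, 2, 6}));         // original order kept
  assert((right == std::vector<std::size_t>{3, 1, 1, 5, 9}));  // original order kept
  return 0;
}
```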
+// Split row indexes (rid_span) into 2 parts (left_part, right_part) depending
+// on the comparison of index values with the split point (split_cond).
+// Handles sparse columns.
+template<bool default_left>
+inline std::pair<size_t, size_t> PartitionSparseKernel(
+    common::Span<const size_t> rid_span, const int32_t split_cond, const Column& column,
+    common::Span<size_t> left_part, common::Span<size_t> right_part) {
+  size_t* p_left_part = left_part.data();
+  size_t* p_right_part = right_part.data();
+
+  size_t nleft_elems = 0;
+  size_t nright_elems = 0;
+
+  if (rid_span.size()) {  // ensure that rid_span is a nonempty range
+    // search first nonzero row with index >= rid_span.front()
+    const size_t* p = std::lower_bound(column.GetRowData(),
+                                       column.GetRowData() + column.Size(),
+                                       rid_span.front());
+
+    if (p != column.GetRowData() + column.Size() && *p <= rid_span.back()) {
+      size_t cursor = p - column.GetRowData();
+
+      for (auto rid : rid_span) {
+        while (cursor < column.Size()
+               && column.GetRowIdx(cursor) < rid
+               && column.GetRowIdx(cursor) <= rid_span.back()) {
+          ++cursor;
+        }
+        if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) {
+          const uint32_t rbin = column.GetFeatureBinIdx(cursor);
+          if (static_cast<int32_t>(rbin + column.GetBaseIdx()) <= split_cond) {
+            p_left_part[nleft_elems++] = rid;
+          } else {
+            p_right_part[nright_elems++] = rid;
+          }
+          ++cursor;
+        } else {
+          // missing value
+          if (default_left) {
+            p_left_part[nleft_elems++] = rid;
+          } else {
+            p_right_part[nright_elems++] = rid;
+          }
+        }
+      }
+    } else {  // all rows in rid_span have missing values
+      if (default_left) {
+        std::copy(rid_span.begin(), rid_span.end(), p_left_part);
+        nleft_elems = rid_span.size();
+      } else {
+        std::copy(rid_span.begin(), rid_span.end(), p_right_part);
+        nright_elems = rid_span.size();
+      }
+    }
+  }
+
+  return {nleft_elems, nright_elems};
+}
+void QuantileHistMaker::Builder::PartitionKernel(
+    const size_t node_in_set, const size_t nid, common::Range1d range,
+    const int32_t split_cond, const ColumnMatrix& column_matrix,
+    const GHistIndexMatrix& gmat, const RegTree& tree) {
+  const size_t* rid = row_set_collection_[nid].begin;
+  common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
+  common::Span<size_t> left = partition_builder_.GetLeftBuffer(node_in_set,
+                                                               range.begin(), range.end());
+  common::Span<size_t> right = partition_builder_.GetRightBuffer(node_in_set,
+                                                                 range.begin(), range.end());
+  const bst_uint fid = tree[nid].SplitIndex();
+  const bool default_left = tree[nid].DefaultLeft();
+  const auto column = column_matrix.GetColumn(fid);
+  const uint32_t offset = column.GetBaseIdx();
+  common::Span<const uint32_t> idx_span = column.GetFeatureBinIdxPtr();
+
+  std::pair<size_t, size_t> child_nodes_sizes;
+
+  if (column.GetType() == xgboost::common::kDenseColumn) {
+    if (default_left) {
+      child_nodes_sizes = PartitionDenseKernel<true>(
+          rid_span, idx_span, split_cond, offset, left, right);
+    } else {
+      child_nodes_sizes = PartitionDenseKernel<false>(
+          rid_span, idx_span, split_cond, offset, left, right);
+    }
+  } else {
+    if (default_left) {
+      child_nodes_sizes = PartitionSparseKernel<true>(rid_span, split_cond, column, left, right);
+    } else {
+      child_nodes_sizes = PartitionSparseKernel<false>(rid_span, split_cond, column, left, right);
+    }
+  }
+
+  const size_t n_left = child_nodes_sizes.first;
+  const size_t n_right = child_nodes_sizes.second;
+
+  partition_builder_.SetNLeftElems(node_in_set, range.begin(), range.end(), n_left);
+  partition_builder_.SetNRightElems(node_in_set, range.begin(), range.end(), n_right);
+}

+void QuantileHistMaker::Builder::FindSplitConditions(const std::vector<ExpandEntry>& nodes,
+                                                     const RegTree& tree,
+                                                     const GHistIndexMatrix& gmat,
+                                                     std::vector<int32_t>* split_conditions) {
+  const size_t n_nodes = nodes.size();
+  split_conditions->resize(n_nodes);
+
+  for (size_t i = 0; i < nodes.size(); ++i) {
+    const int32_t nid = nodes[i].nid;
+    const bst_uint fid = tree[nid].SplitIndex();
+    const bst_float split_pt = tree[nid].SplitCond();
+    const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
+    const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
+    int32_t split_cond = -1;
+    // convert floating-point split_pt into the corresponding bin_id
+    // split_cond = -1 indicates that split_pt is less than all known cut points
+    CHECK_LT(upper_bound,
+             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+    for (uint32_t bin = lower_bound; bin < upper_bound; ++bin) {
+      if (split_pt == gmat.cut.Values()[bin]) {
+        split_cond = static_cast<int32_t>(bin);
+      }
+    }
+    (*split_conditions)[i] = split_cond;
+  }
+}
+
+void QuantileHistMaker::Builder::AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes,
+                                                   RegTree* p_tree) {
+  const size_t n_nodes = nodes.size();
+  for (size_t i = 0; i < n_nodes; ++i) {
+    const int32_t nid = nodes[i].nid;
+    const size_t n_left = partition_builder_.GetNLeftElems(i);
+    const size_t n_right = partition_builder_.GetNRightElems(i);
+
+    row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(),
+                                 (*p_tree)[nid].RightChild(), n_left, n_right);
+  }
+}

+void QuantileHistMaker::Builder::ApplySplit(const std::vector<ExpandEntry> nodes,
                                            const GHistIndexMatrix& gmat,
                                            const ColumnMatrix& column_matrix,
                                            const HistCollection& hist,
-                                            const DMatrix& fmat,
                                            RegTree* p_tree) {
  builder_monitor_.Start("ApplySplit");
  // TODO(hcho3): support feature sampling by levels

-  /* 1. Create child nodes */
-  NodeEntry& e = snode_[nid];
-  bst_float left_leaf_weight =
-      spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate;
-  bst_float right_leaf_weight =
-      spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
-  p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
-                     e.best.DefaultLeft(), e.weight, left_leaf_weight,
-                     right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
+  // 1. Find the split condition for each split
+  const size_t n_nodes = nodes.size();
+  std::vector<int32_t> split_conditions;
+  FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);

-  /* 2. Categorize member rows */
-  const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
-  row_split_tloc_.resize(nthread);
-  for (bst_omp_uint i = 0; i < nthread; ++i) {
-    row_split_tloc_[i].left.clear();
-    row_split_tloc_[i].right.clear();
-  }
-  const bool default_left = (*p_tree)[nid].DefaultLeft();
-  const bst_uint fid = (*p_tree)[nid].SplitIndex();
-  const bst_float split_pt = (*p_tree)[nid].SplitCond();
-  const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
-  const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
-  int32_t split_cond = -1;
-  // convert floating-point split_pt into corresponding bin_id
-  // split_cond = -1 indicates that split_pt is less than all known cut points
-  CHECK_LT(upper_bound,
-           static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-  for (uint32_t i = lower_bound; i < upper_bound; ++i) {
-    if (split_pt == gmat.cut.Values()[i]) {
-      split_cond = static_cast<int32_t>(i);
-    }
-  }
+  // 2.1 Create a blocked space of size SUM(samples in each node)
+  common::BlockedSpace2d space(n_nodes, [&](size_t node_in_set) {
+    int32_t nid = nodes[node_in_set].nid;
+    return row_set_collection_[nid].Size();
+  }, kPartitionBlockSize);

-  const auto& rowset = row_set_collection_[nid];
+  // 2.2 Initialize the partition builder:
+  // allocate buffers to store the intermediate results of each thread
+  partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
+    const int32_t nid = nodes[node_in_set].nid;
+    const size_t size = row_set_collection_[nid].Size();
+    const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
+    return n_tasks;
+  });

-  Column column = column_matrix.GetColumn(fid);
-  if (column.GetType() == xgboost::common::kDenseColumn) {
-    ApplySplitDenseData(rowset, gmat, &row_split_tloc_, column, split_cond,
-                        default_left);
-  } else {
-    ApplySplitSparseData(rowset, gmat, &row_split_tloc_, column, lower_bound,
-                         upper_bound, split_cond, default_left);
-  }
+  // 2.3 Split the elements of row_set_collection_ into the left and right child-nodes
+  // for each node; store the results in the intermediate buffers of partition_builder_
+  common::ParallelFor2d(space, this->nthread_, [&](size_t node_in_set, common::Range1d r) {
+    const int32_t nid = nodes[node_in_set].nid;
+    PartitionKernel(node_in_set, nid, r,
+                    split_conditions[node_in_set], column_matrix, gmat, *p_tree);
+  });
+
+  // 3. Compute the offsets for copying blocks of row indexes
+  // from partition_builder_ to row_set_collection_
+  partition_builder_.CalculateRowOffsets();
+
+  // 4. Copy elements from partition_builder_ back to row_set_collection_
+  // with updated row indexes for each tree node
+  common::ParallelFor2d(space, this->nthread_, [&](size_t node_in_set, common::Range1d r) {
+    const int32_t nid = nodes[node_in_set].nid;
+    partition_builder_.MergeToArray(node_in_set, r.begin(),
+                                    const_cast<size_t*>(row_set_collection_[nid].begin));
+  });
+
+  // 5. Add info about the splits into row_set_collection_
+  AddSplitsToRowSet(nodes, p_tree);

-  row_set_collection_.AddSplit(
-      nid, row_split_tloc_, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild());
  builder_monitor_.Stop("ApplySplit");
}

-void QuantileHistMaker::Builder::ApplySplitDenseData(
-    const RowSetCollection::Elem rowset,
-    const GHistIndexMatrix& gmat,
-    std::vector<RowSetCollection::Split>* p_row_split_tloc,
-    const Column& column,
-    bst_int split_cond,
-    bool default_left) {
-  std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
-  constexpr int kUnroll = 8;  // loop unrolling factor
-  const size_t nrows = rowset.end - rowset.begin;
-  const size_t rest = nrows % kUnroll;
-
-#pragma omp parallel for num_threads(nthread_) schedule(static)
-  for (bst_omp_uint i = 0; i < nrows - rest; i += kUnroll) {
-    const bst_uint tid = omp_get_thread_num();
-    auto& left = row_split_tloc[tid].left;
-    auto& right = row_split_tloc[tid].right;
-    size_t rid[kUnroll];
-    uint32_t rbin[kUnroll];
-    for (int k = 0; k < kUnroll; ++k) {
-      rid[k] = rowset.begin[i + k];
-    }
-    for (int k = 0; k < kUnroll; ++k) {
-      rbin[k] = column.GetFeatureBinIdx(rid[k]);
-    }
-    for (int k = 0; k < kUnroll; ++k) {  // NOLINT
-      if (rbin[k] == std::numeric_limits<uint32_t>::max()) {  // missing value
-        if (default_left) {
-          left.push_back(rid[k]);
-        } else {
-          right.push_back(rid[k]);
-        }
-      } else {
-        if (static_cast<int32_t>(rbin[k] + column.GetBaseIdx()) <= split_cond) {
-          left.push_back(rid[k]);
-        } else {
-          right.push_back(rid[k]);
-        }
-      }
-    }
-  }
-  for (size_t i = nrows - rest; i < nrows; ++i) {
-    auto& left = row_split_tloc[nthread_-1].left;
-    auto& right = row_split_tloc[nthread_-1].right;
-    const size_t rid = rowset.begin[i];
-    const uint32_t rbin = column.GetFeatureBinIdx(rid);
-    if (rbin == std::numeric_limits<uint32_t>::max()) {  // missing value
-      if (default_left) {
-        left.push_back(rid);
-      } else {
-        right.push_back(rid);
-      }
-    } else {
-      if (static_cast<int32_t>(rbin + column.GetBaseIdx()) <= split_cond) {
-        left.push_back(rid);
-      } else {
-        right.push_back(rid);
-      }
-    }
-  }
-}
-
-void QuantileHistMaker::Builder::ApplySplitSparseData(
-    const RowSetCollection::Elem rowset,
-    const GHistIndexMatrix& gmat,
-    std::vector<RowSetCollection::Split>* p_row_split_tloc,
-    const Column& column,
-    bst_uint lower_bound,
-    bst_uint upper_bound,
-    bst_int split_cond,
-    bool default_left) {
-  std::vector<RowSetCollection::Split>& row_split_tloc = *p_row_split_tloc;
-  const size_t nrows = rowset.end - rowset.begin;
-
-#pragma omp parallel num_threads(nthread_)
-  {
-    const auto tid = static_cast<size_t>(omp_get_thread_num());
-    const size_t ibegin = tid * nrows / nthread_;
-    const size_t iend = (tid + 1) * nrows / nthread_;
-    if (ibegin < iend) {  // ensure that [ibegin, iend) is a nonempty range
-      // search first nonzero row with index >= rowset[ibegin]
-      const size_t* p = std::lower_bound(column.GetRowData(),
-                                         column.GetRowData() + column.Size(),
-                                         rowset.begin[ibegin]);
-
-      auto& left = row_split_tloc[tid].left;
-      auto& right = row_split_tloc[tid].right;
-      if (p != column.GetRowData() + column.Size() && *p <= rowset.begin[iend - 1]) {
-        size_t cursor = p - column.GetRowData();
-
-        for (size_t i = ibegin; i < iend; ++i) {
-          const size_t rid = rowset.begin[i];
-          while (cursor < column.Size()
-                 && column.GetRowIdx(cursor) < rid
-                 && column.GetRowIdx(cursor) <= rowset.begin[iend - 1]) {
-            ++cursor;
-          }
-          if (cursor < column.Size() && column.GetRowIdx(cursor) == rid) {
-            const uint32_t rbin = column.GetFeatureBinIdx(cursor);
-            if (static_cast<int32_t>(rbin + column.GetBaseIdx()) <= split_cond) {
-              left.push_back(rid);
-            } else {
-              right.push_back(rid);
-            }
-            ++cursor;
-          } else {
-            // missing value
-            if (default_left) {
-              left.push_back(rid);
-            } else {
-              right.push_back(rid);
-            }
-          }
-        }
-      } else {  // all rows in [ibegin, iend) have missing values
-        if (default_left) {
-          for (size_t i = ibegin; i < iend; ++i) {
-            const size_t rid = rowset.begin[i];
-            left.push_back(rid);
-          }
-        } else {
-          for (size_t i = ibegin; i < iend; ++i) {
-            const size_t rid = rowset.begin[i];
-            right.push_back(rid);
-          }
-        }
-      }
-    }
-  }
-}

void QuantileHistMaker::Builder::InitNewNode(int nid,
                                             const GHistIndexMatrix& gmat,
                                             const std::vector<GradientPair>& gpair,
@@ -979,15 +1059,10 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
// Enumerate the split values of a specific feature.
// Returns the sum of gradients corresponding to the data points that contain a non-missing value
// for the particular feature fid.
-template<int d_step>
+template <int d_step>
GradStats QuantileHistMaker::Builder::EnumerateSplit(
-    const GHistIndexMatrix& gmat,
-    const GHistRow& hist,
-    const NodeEntry& snode,
-    const MetaInfo& info,
-    SplitEntry* p_best,
-    bst_uint fid,
-    bst_uint nodeID) {
+    const GHistIndexMatrix &gmat, const GHistRow &hist, const NodeEntry &snode,
+    SplitEntry *p_best, bst_uint fid, bst_uint nodeID) const {
  CHECK(d_step == +1 || d_step == -1);

  // aliases

@@ -161,7 +161,7 @@ class QuantileHistMaker: public TreeUpdater {
      if (param_.enable_feature_grouping > 0) {
        hist_builder_.BuildBlockHist(gpair, row_indices, gmatb, hist);
      } else {
-        hist_builder_.BuildHist(gpair, row_indices, gmat, hist);
+        hist_builder_.BuildHist(gpair, row_indices, gmat, hist, data_layout_ != kSparseData);
      }
    }
@@ -186,6 +186,13 @@ class QuantileHistMaker: public TreeUpdater {
    unsigned timestamp;
    ExpandEntry(int nid, int sibling_nid, int depth, bst_float loss_chg, unsigned tstmp):
        nid(nid), sibling_nid(sibling_nid), depth(depth), loss_chg(loss_chg), timestamp(tstmp) {}
+
+    bool IsValid(TrainParam const& param, int32_t num_leaves) const {
+      bool ret = loss_chg <= kRtEps ||
+                 (param.max_depth > 0 && this->depth == param.max_depth) ||
+                 (param.max_leaves > 0 && num_leaves == param.max_leaves);
+      return ret;
+    }
  };

// initialize temp data structure
@@ -194,34 +201,27 @@ class QuantileHistMaker: public TreeUpdater {
                   const DMatrix& fmat,
                   const RegTree& tree);

-  void EvaluateSplit(const std::vector<ExpandEntry>& nodes_set,
-                     const GHistIndexMatrix& gmat,
-                     const HistCollection& hist,
-                     const DMatrix& fmat,
-                     const RegTree& tree);
+  void EvaluateSplits(const std::vector<ExpandEntry>& nodes_set,
+                      const GHistIndexMatrix& gmat,
+                      const HistCollection& hist,
+                      const RegTree& tree);

-  void ApplySplit(int nid,
-                  const GHistIndexMatrix& gmat,
-                  const ColumnMatrix& column_matrix,
-                  const HistCollection& hist,
-                  const DMatrix& fmat,
-                  RegTree* p_tree);
+  void ApplySplit(std::vector<ExpandEntry> nodes,
+                  const GHistIndexMatrix& gmat,
+                  const ColumnMatrix& column_matrix,
+                  const HistCollection& hist,
+                  RegTree* p_tree);

-  void ApplySplitDenseData(const RowSetCollection::Elem rowset,
-                           const GHistIndexMatrix& gmat,
-                           std::vector<RowSetCollection::Split>* p_row_split_tloc,
-                           const Column& column,
-                           bst_int split_cond,
-                           bool default_left);
+  void PartitionKernel(const size_t node_in_set, const size_t nid, common::Range1d range,
+                       const int32_t split_cond,
+                       const ColumnMatrix& column_matrix, const GHistIndexMatrix& gmat,
+                       const RegTree& tree);

-  void ApplySplitSparseData(const RowSetCollection::Elem rowset,
-                            const GHistIndexMatrix& gmat,
-                            std::vector<RowSetCollection::Split>* p_row_split_tloc,
-                            const Column& column,
-                            bst_uint lower_bound,
-                            bst_uint upper_bound,
-                            bst_int split_cond,
-                            bool default_left);
+  void AddSplitsToRowSet(const std::vector<ExpandEntry>& nodes, RegTree* p_tree);
+
+  void FindSplitConditions(const std::vector<ExpandEntry>& nodes, const RegTree& tree,
+                           const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions);

  void InitNewNode(int nid,
                   const GHistIndexMatrix& gmat,
@@ -232,15 +232,10 @@ class QuantileHistMaker: public TreeUpdater {
  // Enumerate the split values of a specific feature
  // Returns the sum of gradients corresponding to the data points that contain a non-missing
  // value for the particular feature fid.
-  template<int d_step>
-  GradStats EnumerateSplit(
-      const GHistIndexMatrix& gmat,
-      const GHistRow& hist,
-      const NodeEntry& snode,
-      const MetaInfo& info,
-      SplitEntry* p_best,
-      bst_uint fid,
-      bst_uint nodeID);
+  template <int d_step>
+  GradStats EnumerateSplit(const GHistIndexMatrix &gmat, const GHistRow &hist,
+                           const NodeEntry &snode, SplitEntry *p_best,
+                           bst_uint fid, bst_uint nodeID) const;

  // if the sum of statistics for non-missing values in the node
  // is equal to the sum of statistics for all values:
@@ -286,14 +281,22 @@ class QuantileHistMaker: public TreeUpdater {
                      RegTree *p_tree,
                      const std::vector<GradientPair> &gpair_h);

-  void EvaluateSplits(const GHistIndexMatrix &gmat,
-                      const ColumnMatrix &column_matrix,
-                      DMatrix *p_fmat,
-                      RegTree *p_tree,
-                      int *num_leaves,
-                      int depth,
-                      unsigned *timestamp,
-                      std::vector<ExpandEntry> *temp_qexpand_depth);
+  void EvaluateAndApplySplits(const GHistIndexMatrix &gmat,
+                              const ColumnMatrix &column_matrix,
+                              RegTree *p_tree,
+                              int *num_leaves,
+                              int depth,
+                              unsigned *timestamp,
+                              std::vector<ExpandEntry> *temp_qexpand_depth);
+
+  void AddSplitsToTree(
+      const GHistIndexMatrix &gmat,
+      RegTree *p_tree,
+      int *num_leaves,
+      int depth,
+      unsigned *timestamp,
+      std::vector<ExpandEntry>* nodes_for_apply_split,
+      std::vector<ExpandEntry>* temp_qexpand_depth);

  void ExpandWithLossGuide(const GHistIndexMatrix& gmat,
                           const GHistIndexBlockMatrix& gmatb,
@@ -335,6 +338,9 @@ class QuantileHistMaker: public TreeUpdater {
  std::unique_ptr<SplitEvaluator> spliteval_;
  FeatureInteractionConstraintHost interaction_constraints_;

+  static constexpr size_t kPartitionBlockSize = 2048;
+  common::PartitionBuilder<kPartitionBlockSize> partition_builder_;

  // back pointers to tree and data matrix
  const RegTree* p_last_tree_;
  DMatrix const* const p_last_fmat_;

tests/cpp/common/test_partition_builder.cc (new executable file, +76 lines)
@@ -0,0 +1,76 @@
+#include <gtest/gtest.h>
+#include <algorithm>  // std::max_element (needed below; missing in the original patch)
+#include <vector>
+#include <string>
+#include <utility>
+
+#include "../../../src/common/row_set.h"
+#include "../helpers.h"
+
+namespace xgboost {
+namespace common {
+
+TEST(PartitionBuilder, BasicTest) {
+  constexpr size_t kBlockSize = 16;
+  constexpr size_t kNodes = 5;
+  constexpr size_t kTasks = 3 + 5 + 10 + 1 + 2;
+
+  std::vector<size_t> tasks = { 3, 5, 10, 1, 2 };
+
+  PartitionBuilder<kBlockSize> builder;
+  builder.Init(kTasks, kNodes, [&](size_t i) {
+    return tasks[i];
+  });
+
+  std::vector<size_t> rows_for_left_node = { 2, 12, 0, 16, 8 };
+
+  for (size_t nid = 0; nid < kNodes; ++nid) {
+    size_t value_left = 0;
+    size_t value_right = 0;
+
+    size_t left_total = tasks[nid] * rows_for_left_node[nid];
+
+    for (size_t j = 0; j < tasks[nid]; ++j) {
+      size_t begin = kBlockSize*j;
+      size_t end = kBlockSize*(j+1);
+
+      auto left = builder.GetLeftBuffer(nid, begin, end);
+      auto right = builder.GetRightBuffer(nid, begin, end);
+
+      size_t n_left = rows_for_left_node[nid];
+      size_t n_right = kBlockSize - rows_for_left_node[nid];
+
+      for (size_t i = 0; i < n_left; i++) {
+        left[i] = value_left++;
+      }
+
+      for (size_t i = 0; i < n_right; i++) {
+        right[i] = left_total + value_right++;
+      }
+
+      builder.SetNLeftElems(nid, begin, end, n_left);
+      builder.SetNRightElems(nid, begin, end, n_right);
+    }
+  }
+  builder.CalculateRowOffsets();
+
+  std::vector<size_t> v(*std::max_element(tasks.begin(), tasks.end()) * kBlockSize);
+
+  for (size_t nid = 0; nid < kNodes; ++nid) {
+    for (size_t j = 0; j < tasks[nid]; ++j) {
+      builder.MergeToArray(nid, kBlockSize*j, v.data());
+    }
+
+    for (size_t j = 0; j < tasks[nid] * kBlockSize; ++j) {
+      ASSERT_EQ(v[j], j);
+    }
+    size_t n_left = builder.GetNLeftElems(nid);
+    size_t n_right = builder.GetNRightElems(nid);
+
+    ASSERT_EQ(n_left, rows_for_left_node[nid] * tasks[nid]);
+    ASSERT_EQ(n_right, (kBlockSize - rows_for_left_node[nid]) * tasks[nid]);
+  }
+}
+
+} // namespace common
+} // namespace xgboost
@@ -213,7 +213,7 @@ class QuantileHistMock : public QuantileHistMaker {
    /* Now compare against result given by EvaluateSplit() */
    ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
                     tree.GetDepth(0), snode_[0].best.loss_chg, 0);
-    RealImpl::EvaluateSplit({node}, gmat, hist_, *(*dmat), tree);
+    RealImpl::EvaluateSplits({node}, gmat, hist_, tree);
    ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
    ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);