Optimize ‘hist’ for multi-core CPU (#4529)
* Initial performance optimizations for xgboost * remove includes * revert float->double * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * fix for CI * Check existence of _mm_prefetch and __builtin_prefetch * Fix lint * optimizations for CPU * appling comments in review * add some comments, code refactoring * fixing issues in CI * adding runtime checks * remove 1 extra check * remove extra checks in BuildHist * remove checks * add debug info * added debug info * revert changes * added comments * Apply suggestions from code review Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu> * apply review comments * Remove unused function CreateNewNodes() * Add descriptive comment on node_idx variable in QuantileHistMaker::Builder::BuildHistsBatch()
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
abffbe014e
commit
4d6590be3c
@@ -27,10 +27,10 @@ class RowSetCollection {
|
||||
// id of node associated with this instance set; -1 means uninitialized
|
||||
Elem()
|
||||
= default;
|
||||
Elem(const size_t* begin,
|
||||
const size_t* end,
|
||||
int node_id)
|
||||
: begin(begin), end(end), node_id(node_id) {}
|
||||
Elem(const size_t* begin_,
|
||||
const size_t* end_,
|
||||
int node_id_)
|
||||
: begin(begin_), end(end_), node_id(node_id_) {}
|
||||
|
||||
inline size_t Size() const {
|
||||
return end - begin;
|
||||
@@ -42,6 +42,10 @@ class RowSetCollection {
|
||||
std::vector<size_t> right;
|
||||
};
|
||||
|
||||
size_t Size(unsigned node_id) {
|
||||
return elem_of_each_node_[node_id].Size();
|
||||
}
|
||||
|
||||
inline std::vector<Elem>::const_iterator begin() const { // NOLINT
|
||||
return elem_of_each_node_.begin();
|
||||
}
|
||||
@@ -51,12 +55,12 @@ class RowSetCollection {
|
||||
}
|
||||
|
||||
/*! \brief return corresponding element set given the node_id */
|
||||
inline const Elem& operator[](unsigned node_id) const {
|
||||
const Elem& e = elem_of_each_node_[node_id];
|
||||
CHECK(e.begin != nullptr)
|
||||
<< "access element that is not in the set";
|
||||
inline Elem operator[](unsigned node_id) const {
|
||||
const Elem e = elem_of_each_node_[node_id];
|
||||
return e;
|
||||
}
|
||||
|
||||
|
||||
// clear up things
|
||||
inline void Clear() {
|
||||
elem_of_each_node_.clear();
|
||||
@@ -81,38 +85,29 @@ class RowSetCollection {
|
||||
const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
|
||||
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
|
||||
}
|
||||
|
||||
// split rowset into two
|
||||
inline void AddSplit(unsigned node_id,
|
||||
const std::vector<Split>& row_split_tloc,
|
||||
size_t iLeft,
|
||||
unsigned left_node_id,
|
||||
unsigned right_node_id) {
|
||||
const Elem e = elem_of_each_node_[node_id];
|
||||
const auto nthread = static_cast<bst_omp_uint>(row_split_tloc.size());
|
||||
CHECK(e.begin != nullptr);
|
||||
size_t* all_begin = dmlc::BeginPtr(row_indices_);
|
||||
size_t* begin = all_begin + (e.begin - all_begin);
|
||||
Elem e = elem_of_each_node_[node_id];
|
||||
|
||||
size_t* it = begin;
|
||||
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
|
||||
std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
|
||||
it += row_split_tloc[tid].left.size();
|
||||
}
|
||||
size_t* split_pt = it;
|
||||
for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
|
||||
std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
|
||||
it += row_split_tloc[tid].right.size();
|
||||
}
|
||||
CHECK(e.begin != nullptr);
|
||||
|
||||
size_t* begin = const_cast<size_t*>(e.begin);
|
||||
size_t* split_pt = begin + iLeft;
|
||||
|
||||
if (left_node_id >= elem_of_each_node_.size()) {
|
||||
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
|
||||
elem_of_each_node_.resize((left_node_id + 1)*2, Elem(nullptr, nullptr, -1));
|
||||
}
|
||||
if (right_node_id >= elem_of_each_node_.size()) {
|
||||
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
|
||||
elem_of_each_node_.resize((right_node_id + 1)*2, Elem(nullptr, nullptr, -1));
|
||||
}
|
||||
|
||||
elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id);
|
||||
elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id);
|
||||
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
|
||||
elem_of_each_node_[node_id] = Elem(begin, e.end, -1);
|
||||
}
|
||||
|
||||
// stores the row indices in the set
|
||||
|
||||
Reference in New Issue
Block a user