Improve multi-threaded performance (#2104)

* Add UpdatePredictionCache() option to updaters

Some updaters (e.g. fast_hist) have enough information to quickly compute
the prediction cache for the training data. Each updater may override the
UpdatePredictionCache() method to update the prediction cache. Note: this
trick does not apply to validation data.

* Respond to code review

* Disable some debug messages by default
* Document UpdatePredictionCache() interface
* Remove base_margin logic from UpdatePredictionCache() implementation
* Do not take pointer to cfg, as reference may get stale

* Improve multi-threaded performance

* Use columnwise accessor to accelerate ApplySplit() step,
  with support for a compressed representation
* Parallel sort for evaluation step
* Inline BuildHist() function
* Cache gradient pairs when building histograms in BuildHist()

* Add missing #if macro

* Respond to code review

* Use wrapper to enable parallel sort on Linux

* Fix C++ compatibility issues

* MSVC doesn't support unsigned in OpenMP loops
* gcc 4.6 doesn't support using keyword

* Fix lint issues

* Respond to code review

* Fix bug in ApplySplitSparseData()

* Attempting to read beyond the end of a sparse column
* Mishandling the case where an entire range of rows have missing values

* Fix training continuation bug

Disable UpdatePredictionCache() in the first iteration. This way, we can
accommodate the scenario where we build off of an existing (nonempty) ensemble.

* Add regression test for fast_hist

* Respond to code review

* Add back old version of ApplySplitSparseData
This commit is contained in:
Philip Cho
2017-03-25 10:35:01 -07:00
committed by Tianqi Chen
parent 332aea26a3
commit 14fba01b5a
14 changed files with 719 additions and 171 deletions

View File

@@ -17,15 +17,20 @@ namespace common {
/*! \brief collection of rowset */
class RowSetCollection {
public:
/*! \brief subset of rows */
/*! \brief data structure to store an instance set, a subset of
* rows (instances) associated with a particular node in a decision
* tree. */
struct Elem {
const bst_uint* begin;
const bst_uint* end;
int node_id;
// id of node associated with this instance set; -1 means uninitialized
Elem(void)
: begin(nullptr), end(nullptr) {}
: begin(nullptr), end(nullptr), node_id(-1) {}
Elem(const bst_uint* begin,
const bst_uint* end)
: begin(begin), end(end) {}
const bst_uint* end,
int node_id)
: begin(begin), end(end), node_id(node_id) {}
inline size_t size() const {
return end - begin;
@@ -36,6 +41,15 @@ class RowSetCollection {
std::vector<bst_uint> left;
std::vector<bst_uint> right;
};
/*! \brief return a const iterator to the first entry of elem_of_each_node_,
 * allowing read-only iteration over the stored instance sets */
inline std::vector<Elem>::const_iterator begin() const {
return elem_of_each_node_.begin();
}
/*! \brief return a const iterator one past the last entry of
 * elem_of_each_node_ */
inline std::vector<Elem>::const_iterator end() const {
return elem_of_each_node_.end();
}
/*! \brief return corresponding element set given the node_id */
inline const Elem& operator[](unsigned node_id) const {
const Elem& e = elem_of_each_node_[node_id];
@@ -53,7 +67,7 @@ class RowSetCollection {
CHECK_EQ(elem_of_each_node_.size(), 0U);
const bst_uint* begin = dmlc::BeginPtr(row_indices_);
const bst_uint* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
elem_of_each_node_.emplace_back(Elem(begin, end));
elem_of_each_node_.emplace_back(Elem(begin, end, 0));
}
// split rowset into two
inline void AddSplit(unsigned node_id,
@@ -79,15 +93,15 @@ class RowSetCollection {
}
if (left_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr));
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
}
if (right_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr));
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
}
elem_of_each_node_[left_node_id] = Elem(begin, split_pt);
elem_of_each_node_[right_node_id] = Elem(split_pt, e.end);
elem_of_each_node_[node_id] = Elem(nullptr, nullptr);
elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id);
elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id);
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
}
// stores the row indices in the set