Optimize cpu sketch allreduce for sparse data. (#6009)
* Bypass RABIT serialization reducer and use custom allgather based merging.
This commit is contained in:
@@ -166,6 +166,16 @@ struct WQSummary {
|
||||
* \param src source sketch
|
||||
*/
|
||||
inline void CopyFrom(const WQSummary &src) {
|
||||
if (!src.data) {
|
||||
CHECK_EQ(src.size, 0);
|
||||
size = 0;
|
||||
return;
|
||||
}
|
||||
if (!data) {
|
||||
CHECK_EQ(this->size, 0);
|
||||
CHECK_EQ(src.size, 0);
|
||||
return;
|
||||
}
|
||||
size = src.size;
|
||||
std::memcpy(data, src.data, sizeof(Entry) * size);
|
||||
}
|
||||
@@ -721,6 +731,14 @@ class HostSketchContainer {
|
||||
return use_group_ind;
|
||||
}
|
||||
|
||||
static std::vector<bst_row_t> CalcColumnSize(SparsePage const &page,
|
||||
bst_feature_t const n_columns,
|
||||
size_t const nthreads);
|
||||
|
||||
static std::vector<bst_feature_t> LoadBalance(SparsePage const &page,
|
||||
bst_feature_t n_columns,
|
||||
size_t const nthreads);
|
||||
|
||||
static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
|
||||
size_t const base_rowid) {
|
||||
CHECK_LT(base_rowid, group_ptr.back())
|
||||
@@ -730,6 +748,14 @@ class HostSketchContainer {
|
||||
group_ptr.cbegin() - 1;
|
||||
return group_ind;
|
||||
}
|
||||
// Gather sketches from all workers.
|
||||
void GatherSketchInfo(std::vector<WQSketch::SummaryContainer> const &reduced,
|
||||
std::vector<bst_row_t> *p_worker_segments,
|
||||
std::vector<bst_row_t> *p_sketches_scan,
|
||||
std::vector<WQSketch::Entry> *p_global_sketches);
|
||||
// Merge sketches from all workers.
|
||||
void AllReduce(std::vector<WQSketch::SummaryContainer> *p_reduced,
|
||||
std::vector<int32_t>* p_num_cuts);
|
||||
|
||||
/* \brief Push a CSR matrix. */
|
||||
void PushRowPage(SparsePage const& page, MetaInfo const& info);
|
||||
|
||||
Reference in New Issue
Block a user