Wide dataset quantile performance improvement (#5306)
This commit is contained in:
@@ -286,15 +286,28 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
|
||||
}
|
||||
}
|
||||
|
||||
Init(&sketchs, max_num_bins);
|
||||
Init(&sketchs, max_num_bins, info.num_row_);
|
||||
monitor_.Stop(__FUNCTION__);
|
||||
}
|
||||
|
||||
/**
|
||||
* \param [in,out] in_sketchs
|
||||
* \param max_num_bins The maximum number bins.
|
||||
* \param max_rows Number of rows in this DMatrix.
|
||||
*/
|
||||
void DenseCuts::Init
|
||||
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins) {
|
||||
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins, size_t max_rows) {
|
||||
monitor_.Start(__func__);
|
||||
std::vector<WQSketch>& sketchs = *in_sketchs;
|
||||
|
||||
// Compute how many cuts samples we need at each node
|
||||
// Do not require more than the number of total rows in training data
|
||||
// This allows efficient training on wide data
|
||||
size_t global_max_rows = max_rows;
|
||||
rabit::Allreduce<rabit::op::Sum>(&global_max_rows, 1);
|
||||
constexpr int kFactor = 8;
|
||||
size_t intermediate_num_cuts =
|
||||
std::min(global_max_rows, static_cast<size_t>(max_num_bins * kFactor));
|
||||
// gather the histogram data
|
||||
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
|
||||
std::vector<WQSketch::SummaryContainer> summary_array;
|
||||
@@ -302,11 +315,11 @@ void DenseCuts::Init
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
WQSketch::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array[i].Reserve(max_num_bins * kFactor);
|
||||
summary_array[i].SetPrune(out, max_num_bins * kFactor);
|
||||
summary_array[i].Reserve(intermediate_num_cuts);
|
||||
summary_array[i].SetPrune(out, intermediate_num_cuts);
|
||||
}
|
||||
CHECK_EQ(summary_array.size(), in_sketchs->size());
|
||||
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
|
||||
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
|
||||
// TODO(chenqin): rabit failure recovery assumes no boostrap onetime call after loadcheckpoint
|
||||
// we need to move this allreduce before loadcheckpoint call in future
|
||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
|
||||
Reference in New Issue
Block a user