Wide dataset quantile performance improvement (#5306)

This commit is contained in:
Rory Mitchell 2020-02-16 10:24:42 +13:00 committed by GitHub
parent ed2465cce4
commit 7e32af5c21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 8 deletions

View File

@ -286,15 +286,28 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
}
}
Init(&sketchs, max_num_bins);
Init(&sketchs, max_num_bins, info.num_row_);
monitor_.Stop(__FUNCTION__);
}
/**
* \param [in,out] in_sketchs
* \param max_num_bins The maximum number of bins.
* \param max_rows Number of rows in this DMatrix.
*/
void DenseCuts::Init
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins) {
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins, size_t max_rows) {
monitor_.Start(__func__);
std::vector<WQSketch>& sketchs = *in_sketchs;
// Compute how many cut samples we need at each node
// Do not require more than the total number of rows in the training data
// This allows efficient training on wide data
size_t global_max_rows = max_rows;
rabit::Allreduce<rabit::op::Sum>(&global_max_rows, 1);
constexpr int kFactor = 8;
size_t intermediate_num_cuts =
std::min(global_max_rows, static_cast<size_t>(max_num_bins * kFactor));
// gather the histogram data
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
std::vector<WQSketch::SummaryContainer> summary_array;
@ -302,11 +315,11 @@ void DenseCuts::Init
for (size_t i = 0; i < sketchs.size(); ++i) {
WQSketch::SummaryContainer out;
sketchs[i].GetSummary(&out);
summary_array[i].Reserve(max_num_bins * kFactor);
summary_array[i].SetPrune(out, max_num_bins * kFactor);
summary_array[i].Reserve(intermediate_num_cuts);
summary_array[i].SetPrune(out, intermediate_num_cuts);
}
CHECK_EQ(summary_array.size(), in_sketchs->size());
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
// TODO(chenqin): rabit failure recovery assumes no bootstrap one-time call after loadcheckpoint
// we need to move this allreduce before loadcheckpoint call in future
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());

View File

@ -148,7 +148,7 @@ class GPUSketcher {
this->SketchBatch(batch, info);
}
hmat->Init(&sketch_container_->sketches_, max_bin_);
hmat->Init(&sketch_container_->sketches_, max_bin_, info.num_row_);
return row_stride_;
}

View File

@ -244,7 +244,7 @@ class DenseCuts : public CutsBuilder {
CutsBuilder(container) {
monitor_.Init(__FUNCTION__);
}
void Init(std::vector<WQSketch>* sketchs, uint32_t max_num_bins);
void Init(std::vector<WQSketch>* sketchs, uint32_t max_num_bins, size_t max_rows);
void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
};

View File

@ -702,6 +702,7 @@ class QuantileSketchTemplate {
nlevel = 1;
while (true) {
limit_size = static_cast<size_t>(ceil(nlevel / eps)) + 1;
limit_size = std::min(maxn, limit_size);
size_t n = (1ULL << nlevel);
if (n * limit_size >= maxn) break;
++nlevel;
@ -709,7 +710,8 @@ class QuantileSketchTemplate {
// check invariant
size_t n = (1ULL << nlevel);
CHECK(n * limit_size >= maxn) << "invalid init parameter";
CHECK(nlevel <= limit_size * eps) << "invalid init parameter";
CHECK(nlevel <= std::max(1, static_cast<int>(limit_size * eps)))
<< "invalid init parameter";
}
/*!