Wide dataset quantile performance improvement (#5306)
This commit is contained in:
parent
ed2465cce4
commit
7e32af5c21
@ -286,15 +286,28 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
|
||||
}
|
||||
}
|
||||
|
||||
Init(&sketchs, max_num_bins);
|
||||
Init(&sketchs, max_num_bins, info.num_row_);
|
||||
monitor_.Stop(__FUNCTION__);
|
||||
}
|
||||
|
||||
/**
|
||||
* \param [in,out] in_sketchs
|
||||
* \param max_num_bins The maximum number bins.
|
||||
* \param max_rows Number of rows in this DMatrix.
|
||||
*/
|
||||
void DenseCuts::Init
|
||||
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins) {
|
||||
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins, size_t max_rows) {
|
||||
monitor_.Start(__func__);
|
||||
std::vector<WQSketch>& sketchs = *in_sketchs;
|
||||
|
||||
// Compute how many cuts samples we need at each node
|
||||
// Do not require more than the number of total rows in training data
|
||||
// This allows efficient training on wide data
|
||||
size_t global_max_rows = max_rows;
|
||||
rabit::Allreduce<rabit::op::Sum>(&global_max_rows, 1);
|
||||
constexpr int kFactor = 8;
|
||||
size_t intermediate_num_cuts =
|
||||
std::min(global_max_rows, static_cast<size_t>(max_num_bins * kFactor));
|
||||
// gather the histogram data
|
||||
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
|
||||
std::vector<WQSketch::SummaryContainer> summary_array;
|
||||
@ -302,11 +315,11 @@ void DenseCuts::Init
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
WQSketch::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array[i].Reserve(max_num_bins * kFactor);
|
||||
summary_array[i].SetPrune(out, max_num_bins * kFactor);
|
||||
summary_array[i].Reserve(intermediate_num_cuts);
|
||||
summary_array[i].SetPrune(out, intermediate_num_cuts);
|
||||
}
|
||||
CHECK_EQ(summary_array.size(), in_sketchs->size());
|
||||
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
|
||||
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
|
||||
// TODO(chenqin): rabit failure recovery assumes no boostrap onetime call after loadcheckpoint
|
||||
// we need to move this allreduce before loadcheckpoint call in future
|
||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
|
||||
@ -148,7 +148,7 @@ class GPUSketcher {
|
||||
this->SketchBatch(batch, info);
|
||||
}
|
||||
|
||||
hmat->Init(&sketch_container_->sketches_, max_bin_);
|
||||
hmat->Init(&sketch_container_->sketches_, max_bin_, info.num_row_);
|
||||
return row_stride_;
|
||||
}
|
||||
|
||||
|
||||
@ -244,7 +244,7 @@ class DenseCuts : public CutsBuilder {
|
||||
CutsBuilder(container) {
|
||||
monitor_.Init(__FUNCTION__);
|
||||
}
|
||||
void Init(std::vector<WQSketch>* sketchs, uint32_t max_num_bins);
|
||||
void Init(std::vector<WQSketch>* sketchs, uint32_t max_num_bins, size_t max_rows);
|
||||
void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
|
||||
};
|
||||
|
||||
|
||||
@ -702,6 +702,7 @@ class QuantileSketchTemplate {
|
||||
nlevel = 1;
|
||||
while (true) {
|
||||
limit_size = static_cast<size_t>(ceil(nlevel / eps)) + 1;
|
||||
limit_size = std::min(maxn, limit_size);
|
||||
size_t n = (1ULL << nlevel);
|
||||
if (n * limit_size >= maxn) break;
|
||||
++nlevel;
|
||||
@ -709,7 +710,8 @@ class QuantileSketchTemplate {
|
||||
// check invariant
|
||||
size_t n = (1ULL << nlevel);
|
||||
CHECK(n * limit_size >= maxn) << "invalid init parameter";
|
||||
CHECK(nlevel <= limit_size * eps) << "invalid init parameter";
|
||||
CHECK(nlevel <= std::max(1, static_cast<int>(limit_size * eps)))
|
||||
<< "invalid init parameter";
|
||||
}
|
||||
|
||||
/*!
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user