Wide dataset quantile performance improvement (#5306)
This commit is contained in:
parent
ed2465cce4
commit
7e32af5c21
@ -286,15 +286,28 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Init(&sketchs, max_num_bins);
|
Init(&sketchs, max_num_bins, info.num_row_);
|
||||||
monitor_.Stop(__FUNCTION__);
|
monitor_.Stop(__FUNCTION__);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \param [in,out] in_sketchs
|
||||||
|
* \param max_num_bins The maximum number of bins.
|
||||||
|
* \param max_rows Number of rows in this DMatrix.
|
||||||
|
*/
|
||||||
void DenseCuts::Init
|
void DenseCuts::Init
|
||||||
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins) {
|
(std::vector<WQSketch>* in_sketchs, uint32_t max_num_bins, size_t max_rows) {
|
||||||
monitor_.Start(__func__);
|
monitor_.Start(__func__);
|
||||||
std::vector<WQSketch>& sketchs = *in_sketchs;
|
std::vector<WQSketch>& sketchs = *in_sketchs;
|
||||||
|
|
||||||
|
// Compute how many cut samples we need at each node
|
||||||
|
// Do not require more than the number of total rows in training data
|
||||||
|
// This allows efficient training on wide data
|
||||||
|
size_t global_max_rows = max_rows;
|
||||||
|
rabit::Allreduce<rabit::op::Sum>(&global_max_rows, 1);
|
||||||
constexpr int kFactor = 8;
|
constexpr int kFactor = 8;
|
||||||
|
size_t intermediate_num_cuts =
|
||||||
|
std::min(global_max_rows, static_cast<size_t>(max_num_bins * kFactor));
|
||||||
// gather the histogram data
|
// gather the histogram data
|
||||||
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
|
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
|
||||||
std::vector<WQSketch::SummaryContainer> summary_array;
|
std::vector<WQSketch::SummaryContainer> summary_array;
|
||||||
@ -302,11 +315,11 @@ void DenseCuts::Init
|
|||||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||||
WQSketch::SummaryContainer out;
|
WQSketch::SummaryContainer out;
|
||||||
sketchs[i].GetSummary(&out);
|
sketchs[i].GetSummary(&out);
|
||||||
summary_array[i].Reserve(max_num_bins * kFactor);
|
summary_array[i].Reserve(intermediate_num_cuts);
|
||||||
summary_array[i].SetPrune(out, max_num_bins * kFactor);
|
summary_array[i].SetPrune(out, intermediate_num_cuts);
|
||||||
}
|
}
|
||||||
CHECK_EQ(summary_array.size(), in_sketchs->size());
|
CHECK_EQ(summary_array.size(), in_sketchs->size());
|
||||||
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
|
size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
|
||||||
// TODO(chenqin): rabit failure recovery assumes no bootstrap onetime call after loadcheckpoint
|
// TODO(chenqin): rabit failure recovery assumes no bootstrap onetime call after loadcheckpoint
|
||||||
// we need to move this allreduce before loadcheckpoint call in future
|
// we need to move this allreduce before loadcheckpoint call in future
|
||||||
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
|
||||||
|
|||||||
@ -148,7 +148,7 @@ class GPUSketcher {
|
|||||||
this->SketchBatch(batch, info);
|
this->SketchBatch(batch, info);
|
||||||
}
|
}
|
||||||
|
|
||||||
hmat->Init(&sketch_container_->sketches_, max_bin_);
|
hmat->Init(&sketch_container_->sketches_, max_bin_, info.num_row_);
|
||||||
return row_stride_;
|
return row_stride_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -244,7 +244,7 @@ class DenseCuts : public CutsBuilder {
|
|||||||
CutsBuilder(container) {
|
CutsBuilder(container) {
|
||||||
monitor_.Init(__FUNCTION__);
|
monitor_.Init(__FUNCTION__);
|
||||||
}
|
}
|
||||||
void Init(std::vector<WQSketch>* sketchs, uint32_t max_num_bins);
|
void Init(std::vector<WQSketch>* sketchs, uint32_t max_num_bins, size_t max_rows);
|
||||||
void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
|
void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -702,6 +702,7 @@ class QuantileSketchTemplate {
|
|||||||
nlevel = 1;
|
nlevel = 1;
|
||||||
while (true) {
|
while (true) {
|
||||||
limit_size = static_cast<size_t>(ceil(nlevel / eps)) + 1;
|
limit_size = static_cast<size_t>(ceil(nlevel / eps)) + 1;
|
||||||
|
limit_size = std::min(maxn, limit_size);
|
||||||
size_t n = (1ULL << nlevel);
|
size_t n = (1ULL << nlevel);
|
||||||
if (n * limit_size >= maxn) break;
|
if (n * limit_size >= maxn) break;
|
||||||
++nlevel;
|
++nlevel;
|
||||||
@ -709,7 +710,8 @@ class QuantileSketchTemplate {
|
|||||||
// check invariant
|
// check invariant
|
||||||
size_t n = (1ULL << nlevel);
|
size_t n = (1ULL << nlevel);
|
||||||
CHECK(n * limit_size >= maxn) << "invalid init parameter";
|
CHECK(n * limit_size >= maxn) << "invalid init parameter";
|
||||||
CHECK(nlevel <= limit_size * eps) << "invalid init parameter";
|
CHECK(nlevel <= std::max(1, static_cast<int>(limit_size * eps)))
|
||||||
|
<< "invalid init parameter";
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user