Implement GK sketching on GPU. (#5846)
* Implement GK sketching on GPU.
* Strong tests on quantile building.
* Handle sparse dataset by binary searching the column index.
* Hypothesis test on dask.
This commit is contained in:
@@ -158,7 +158,6 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
|
||||
uint32_t beg_col, uint32_t end_col,
|
||||
uint32_t thread_id) {
|
||||
CHECK_GE(end_col, beg_col);
|
||||
constexpr float kFactor = 8;
|
||||
|
||||
// Data groups, used in ranking.
|
||||
std::vector<bst_uint> const& group_ptr = info.group_ptr_;
|
||||
@@ -175,11 +174,12 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
|
||||
max_num_bins);
|
||||
if (n_bins == 0) {
|
||||
// cut_ptrs_ is initialized with a zero, so there's always an element at the back
|
||||
CHECK_GE(local_ptrs.size(), 1);
|
||||
local_ptrs.emplace_back(local_ptrs.back());
|
||||
continue;
|
||||
}
|
||||
|
||||
sketch.Init(info.num_row_, 1.0 / (n_bins * kFactor));
|
||||
sketch.Init(info.num_row_, 1.0 / (n_bins * WQSketch::kFactor));
|
||||
for (auto const& entry : column) {
|
||||
uint32_t weight_ind = 0;
|
||||
if (use_group_ind) {
|
||||
@@ -329,7 +329,6 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
|
||||
const MetaInfo& info = p_fmat->Info();
|
||||
|
||||
// safe factor for better accuracy
|
||||
constexpr int kFactor = 8;
|
||||
std::vector<WQSketch> sketchs;
|
||||
|
||||
const int nthread = omp_get_max_threads();
|
||||
@@ -339,7 +338,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
|
||||
unsigned const ncol = static_cast<unsigned>(info.num_col_);
|
||||
sketchs.resize(info.num_col_);
|
||||
for (auto& s : sketchs) {
|
||||
s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor));
|
||||
s.Init(info.num_row_, 1.0 / (max_num_bins * WQSketch::kFactor));
|
||||
}
|
||||
|
||||
// Data groups, used in ranking.
|
||||
@@ -410,9 +409,8 @@ void DenseCuts::Init
|
||||
// This allows efficient training on wide data
|
||||
size_t global_max_rows = max_rows;
|
||||
rabit::Allreduce<rabit::op::Sum>(&global_max_rows, 1);
|
||||
constexpr int kFactor = 8;
|
||||
size_t intermediate_num_cuts =
|
||||
std::min(global_max_rows, static_cast<size_t>(max_num_bins * kFactor));
|
||||
std::min(global_max_rows, static_cast<size_t>(max_num_bins * WQSketch::kFactor));
|
||||
// gather the histogram data
|
||||
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
|
||||
std::vector<WQSketch::SummaryContainer> summary_array;
|
||||
|
||||
Reference in New Issue
Block a user