Implement GK sketching on GPU. (#5846)

* Implement GK (Greenwald–Khanna) quantile sketching on the GPU.
* Add strong tests covering quantile building.
* Handle sparse datasets by binary-searching the column index.
* Add Hypothesis-based property tests for the Dask interface.
This commit is contained in:
Jiaming Yuan
2020-07-07 12:16:21 +08:00
committed by GitHub
parent ac3f0e78dc
commit 048d969be4
25 changed files with 2045 additions and 405 deletions

View File

@@ -57,80 +57,52 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
size_t nnz = 0;
// Sketch for all batches.
iter.Reset();
common::HistogramCuts cuts;
common::DenseCuts dense_cuts(&cuts);
std::vector<common::SketchContainer> sketch_containers;
size_t batches = 0;
size_t accumulated_rows = 0;
bst_feature_t cols = 0;
int32_t device = -1;
while (iter.Next()) {
auto device = proxy->DeviceIdx();
device = proxy->DeviceIdx();
dh::safe_cuda(cudaSetDevice(device));
if (cols == 0) {
cols = num_cols();
} else {
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
}
sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows());
sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device);
auto* p_sketch = &sketch_containers.back();
if (proxy->Info().weights_.Size() != 0) {
proxy->Info().weights_.SetDevice(device);
Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
proxy->Info(),
missing, device, p_sketch);
});
} else {
Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketch(value, batch_param_.max_bin, missing,
device, p_sketch);
});
}
proxy->Info().weights_.SetDevice(device);
Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
proxy->Info(), missing, p_sketch);
});
auto batch_rows = num_rows();
accumulated_rows += batch_rows;
dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(),
row_counts.size());
row_stride =
std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, device, missing);
}));
nnz += thrust::reduce(thrust::cuda::par(alloc),
row_counts.begin(), row_counts.end());
row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) {
return GetRowCounts(value, row_counts_span,
device, missing);
}));
nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
row_counts.end());
batches++;
}
// Merging multiple batches for each column
std::vector<common::WQSketch::SummaryContainer> summary_array(cols);
size_t intermediate_num_cuts = std::min(
accumulated_rows, static_cast<size_t>(batch_param_.max_bin *
common::SketchContainer::kFactor));
size_t nbytes =
common::WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
#pragma omp parallel for num_threads(nthread) if (nthread > 0)
for (omp_ulong c = 0; c < cols; ++c) {
for (auto& sketch_batch : sketch_containers) {
common::WQSketch::SummaryContainer summary;
sketch_batch.sketches_.at(c).GetSummary(&summary);
sketch_batch.sketches_.at(c).Init(0, 1);
summary_array.at(c).Reduce(summary, nbytes);
}
common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
for (auto const& sketch : sketch_containers) {
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
final_sketch.FixError();
}
sketch_containers.clear();
sketch_containers.shrink_to_fit();
// Build the final summary.
std::vector<common::WQSketch> sketches(cols);
#pragma omp parallel for num_threads(nthread) if (nthread > 0)
for (omp_ulong c = 0; c < cols; ++c) {
sketches.at(c).Init(
accumulated_rows,
1.0 / (common::SketchContainer::kFactor * batch_param_.max_bin));
sketches.at(c).PushSummary(summary_array.at(c));
}
dense_cuts.Init(&sketches, batch_param_.max_bin, accumulated_rows);
summary_array.clear();
common::HistogramCuts cuts;
final_sketch.MakeCuts(&cuts);
this->info_.num_col_ = cols;
this->info_.num_row_ = accumulated_rows;