Implement GK sketching on GPU. (#5846)
* Implement GK sketching on GPU.
* Strong tests on quantile building.
* Handle sparse dataset by binary searching the column index.
* Hypothesis test on Dask.
This commit is contained in:
@@ -57,80 +57,52 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
|
||||
size_t nnz = 0;
|
||||
// Sketch for all batches.
|
||||
iter.Reset();
|
||||
common::HistogramCuts cuts;
|
||||
common::DenseCuts dense_cuts(&cuts);
|
||||
|
||||
std::vector<common::SketchContainer> sketch_containers;
|
||||
size_t batches = 0;
|
||||
size_t accumulated_rows = 0;
|
||||
bst_feature_t cols = 0;
|
||||
int32_t device = -1;
|
||||
while (iter.Next()) {
|
||||
auto device = proxy->DeviceIdx();
|
||||
device = proxy->DeviceIdx();
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
if (cols == 0) {
|
||||
cols = num_cols();
|
||||
} else {
|
||||
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
|
||||
}
|
||||
sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows());
|
||||
sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device);
|
||||
auto* p_sketch = &sketch_containers.back();
|
||||
if (proxy->Info().weights_.Size() != 0) {
|
||||
proxy->Info().weights_.SetDevice(device);
|
||||
Dispatch(proxy, [&](auto const &value) {
|
||||
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
|
||||
proxy->Info(),
|
||||
missing, device, p_sketch);
|
||||
});
|
||||
} else {
|
||||
Dispatch(proxy, [&](auto const &value) {
|
||||
common::AdapterDeviceSketch(value, batch_param_.max_bin, missing,
|
||||
device, p_sketch);
|
||||
});
|
||||
}
|
||||
proxy->Info().weights_.SetDevice(device);
|
||||
Dispatch(proxy, [&](auto const &value) {
|
||||
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
|
||||
proxy->Info(), missing, p_sketch);
|
||||
});
|
||||
|
||||
auto batch_rows = num_rows();
|
||||
accumulated_rows += batch_rows;
|
||||
dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
|
||||
common::Span<size_t> row_counts_span(row_counts.data().get(),
|
||||
row_counts.size());
|
||||
row_stride =
|
||||
std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
|
||||
return GetRowCounts(value, row_counts_span, device, missing);
|
||||
}));
|
||||
nnz += thrust::reduce(thrust::cuda::par(alloc),
|
||||
row_counts.begin(), row_counts.end());
|
||||
row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) {
|
||||
return GetRowCounts(value, row_counts_span,
|
||||
device, missing);
|
||||
}));
|
||||
nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
|
||||
row_counts.end());
|
||||
batches++;
|
||||
}
|
||||
|
||||
// Merging multiple batches for each column
|
||||
std::vector<common::WQSketch::SummaryContainer> summary_array(cols);
|
||||
size_t intermediate_num_cuts = std::min(
|
||||
accumulated_rows, static_cast<size_t>(batch_param_.max_bin *
|
||||
common::SketchContainer::kFactor));
|
||||
size_t nbytes =
|
||||
common::WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
|
||||
#pragma omp parallel for num_threads(nthread) if (nthread > 0)
|
||||
for (omp_ulong c = 0; c < cols; ++c) {
|
||||
for (auto& sketch_batch : sketch_containers) {
|
||||
common::WQSketch::SummaryContainer summary;
|
||||
sketch_batch.sketches_.at(c).GetSummary(&summary);
|
||||
sketch_batch.sketches_.at(c).Init(0, 1);
|
||||
summary_array.at(c).Reduce(summary, nbytes);
|
||||
}
|
||||
common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
|
||||
for (auto const& sketch : sketch_containers) {
|
||||
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
|
||||
final_sketch.FixError();
|
||||
}
|
||||
sketch_containers.clear();
|
||||
sketch_containers.shrink_to_fit();
|
||||
|
||||
// Build the final summary.
|
||||
std::vector<common::WQSketch> sketches(cols);
|
||||
#pragma omp parallel for num_threads(nthread) if (nthread > 0)
|
||||
for (omp_ulong c = 0; c < cols; ++c) {
|
||||
sketches.at(c).Init(
|
||||
accumulated_rows,
|
||||
1.0 / (common::SketchContainer::kFactor * batch_param_.max_bin));
|
||||
sketches.at(c).PushSummary(summary_array.at(c));
|
||||
}
|
||||
dense_cuts.Init(&sketches, batch_param_.max_bin, accumulated_rows);
|
||||
summary_array.clear();
|
||||
common::HistogramCuts cuts;
|
||||
final_sketch.MakeCuts(&cuts);
|
||||
|
||||
this->info_.num_col_ = cols;
|
||||
this->info_.num_row_ = accumulated_rows;
|
||||
|
||||
Reference in New Issue
Block a user