Sketching from adapters (#5365)

* Sketching from adapters

* Add weights test
This commit is contained in:
Rory Mitchell
2020-03-07 21:07:58 +13:00
committed by GitHub
parent 0dd97c206b
commit a38e7bd19c
11 changed files with 780 additions and 624 deletions

View File

@@ -78,6 +78,20 @@ EllpackPageImpl::EllpackPageImpl(int device, EllpackInfo info, size_t n_rows) {
monitor_.StopCuda("InitCompressedData");
}
size_t GetRowStride(DMatrix* dmat) {
if (dmat->IsDense()) return dmat->Info().num_col_;
size_t row_stride = 0;
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
const auto& row_offset = batch.offset.ConstHostVector();
for (auto i = 1ull; i < row_offset.size(); i++) {
row_stride = std::max(
row_stride, static_cast<size_t>(row_offset[i] - row_offset[i - 1]));
}
}
return row_stride;
}
// Construct an ELLPACK matrix in memory.
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
monitor_.Init("ellpack_page");
@@ -87,13 +101,13 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) {
monitor_.StartCuda("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
common::HistogramCuts hmat;
size_t row_stride =
common::DeviceSketch(param.gpu_id, param.max_bin, param.gpu_batch_nrows, dmat, &hmat);
size_t row_stride = GetRowStride(dmat);
auto cuts = common::DeviceSketch(param.gpu_id, dmat, param.max_bin,
param.gpu_batch_nrows);
monitor_.StopCuda("Quantiles");
monitor_.StartCuda("InitEllpackInfo");
InitInfo(param.gpu_id, dmat->IsDense(), row_stride, hmat);
InitInfo(param.gpu_id, dmat->IsDense(), row_stride, cuts);
monitor_.StopCuda("InitEllpackInfo");
monitor_.StartCuda("InitCompressedData");