Use Booster context in DMatrix. (#8896)

- Pass context from booster to DMatrix.
- Use context instead of integer for `n_threads`.
- Check the consistency configuration for `max_bin`.
- Test for all combinations of initialization options.
This commit is contained in:
Jiaming Yuan
2023-04-28 21:47:14 +08:00
committed by GitHub
parent 1f9a57d17b
commit 08ce495b5d
67 changed files with 1283 additions and 935 deletions

View File

@@ -17,8 +17,8 @@ namespace xgboost {
EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {}
EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param)
: impl_{new EllpackPageImpl(dmat, param)} {}
EllpackPage::EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param)
: impl_{new EllpackPageImpl{ctx, dmat, param}} {}
EllpackPage::~EllpackPage() = default;
@@ -105,29 +105,29 @@ EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
}
// Construct an ELLPACK matrix in memory.
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param)
EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchParam& param)
: is_dense(dmat->IsDense()) {
monitor_.Init("ellpack_page");
dh::safe_cuda(cudaSetDevice(param.gpu_id));
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
n_rows = dmat->Info().num_row_;
monitor_.Start("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
row_stride = GetRowStride(dmat);
cuts_ = common::DeviceSketch(param.gpu_id, dmat, param.max_bin);
cuts_ = common::DeviceSketch(ctx->gpu_id, dmat, param.max_bin);
monitor_.Stop("Quantiles");
monitor_.Start("InitCompressedData");
this->InitCompressedData(param.gpu_id);
this->InitCompressedData(ctx->gpu_id);
monitor_.Stop("InitCompressedData");
dmat->Info().feature_types.SetDevice(param.gpu_id);
dmat->Info().feature_types.SetDevice(ctx->gpu_id);
auto ft = dmat->Info().feature_types.ConstDeviceSpan();
monitor_.Start("BinningCompression");
CHECK(dmat->SingleColBlock());
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
CreateHistIndices(param.gpu_id, batch, ft);
CreateHistIndices(ctx->gpu_id, batch, ft);
}
monitor_.Stop("BinningCompression");
}