merge 23Mar01

This commit is contained in:
amdsc21
2023-05-02 00:05:58 +02:00
258 changed files with 7471 additions and 5379 deletions

View File

@@ -21,8 +21,8 @@ namespace xgboost {
EllpackPage::EllpackPage() : impl_{new EllpackPageImpl()} {}
EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param)
: impl_{new EllpackPageImpl(dmat, param)} {}
EllpackPage::EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param)
: impl_{new EllpackPageImpl{ctx, dmat, param}} {}
EllpackPage::~EllpackPage() = default;
@@ -114,14 +114,13 @@ EllpackPageImpl::EllpackPageImpl(int device, common::HistogramCuts cuts,
}
// Construct an ELLPACK matrix in memory.
EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param)
EllpackPageImpl::EllpackPageImpl(Context const* ctx, DMatrix* dmat, const BatchParam& param)
: is_dense(dmat->IsDense()) {
monitor_.Init("ellpack_page");
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(param.gpu_id));
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(param.gpu_id));
dh::safe_cuda(hipSetDevice(ctx->gpu_id));
#endif
n_rows = dmat->Info().num_row_;
@@ -129,19 +128,19 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param)
monitor_.Start("Quantiles");
// Create the quantile sketches for the dmatrix and initialize HistogramCuts.
row_stride = GetRowStride(dmat);
cuts_ = common::DeviceSketch(param.gpu_id, dmat, param.max_bin);
cuts_ = common::DeviceSketch(ctx->gpu_id, dmat, param.max_bin);
monitor_.Stop("Quantiles");
monitor_.Start("InitCompressedData");
this->InitCompressedData(param.gpu_id);
this->InitCompressedData(ctx->gpu_id);
monitor_.Stop("InitCompressedData");
dmat->Info().feature_types.SetDevice(param.gpu_id);
dmat->Info().feature_types.SetDevice(ctx->gpu_id);
auto ft = dmat->Info().feature_types.ConstDeviceSpan();
monitor_.Start("BinningCompression");
CHECK(dmat->SingleColBlock());
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
CreateHistIndices(param.gpu_id, batch, ft);
CreateHistIndices(ctx->gpu_id, batch, ft);
}
monitor_.Stop("BinningCompression");
}