Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -1,22 +1,24 @@
|
||||
/*!
|
||||
* Copyright 2020-2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2020-2023, XGBoost contributors
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "../common/hist_util.cuh"
|
||||
#include "batch_utils.h" // for RegenGHist
|
||||
#include "device_adapter.cuh"
|
||||
#include "ellpack_page.cuh"
|
||||
#include "gradient_index.h"
|
||||
#include "iterative_dmatrix.h"
|
||||
#include "proxy_dmatrix.cuh"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
#include "sparse_page_source.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
||||
namespace xgboost::data {
|
||||
void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
|
||||
DataIterHandle iter_handle, float missing,
|
||||
std::shared_ptr<DMatrix> ref) {
|
||||
// A handle passed to external iterator.
|
||||
DMatrixProxy* proxy = MakeProxy(proxy_);
|
||||
@@ -46,7 +48,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
||||
int32_t current_device;
|
||||
dh::safe_cuda(cudaGetDevice(¤t_device));
|
||||
auto get_device = [&]() -> int32_t {
|
||||
int32_t d = (ctx_.gpu_id == Context::kCpuId) ? current_device : ctx_.gpu_id;
|
||||
std::int32_t d = (ctx->gpu_id == Context::kCpuId) ? current_device : ctx->gpu_id;
|
||||
CHECK_NE(d, Context::kCpuId);
|
||||
return d;
|
||||
};
|
||||
@@ -57,8 +59,8 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
||||
common::HistogramCuts cuts;
|
||||
do {
|
||||
// We use do while here as the first batch is fetched in ctor
|
||||
ctx_.gpu_id = proxy->DeviceIdx();
|
||||
CHECK_LT(ctx_.gpu_id, common::AllVisibleGPUs());
|
||||
// ctx_.gpu_id = proxy->DeviceIdx();
|
||||
CHECK_LT(ctx->gpu_id, common::AllVisibleGPUs());
|
||||
dh::safe_cuda(cudaSetDevice(get_device()));
|
||||
if (cols == 0) {
|
||||
cols = num_cols();
|
||||
@@ -68,12 +70,12 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
||||
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
|
||||
}
|
||||
if (!ref) {
|
||||
sketch_containers.emplace_back(proxy->Info().feature_types, batch_param_.max_bin, cols,
|
||||
num_rows(), get_device());
|
||||
sketch_containers.emplace_back(proxy->Info().feature_types, p.max_bin, cols, num_rows(),
|
||||
get_device());
|
||||
auto* p_sketch = &sketch_containers.back();
|
||||
proxy->Info().weights_.SetDevice(get_device());
|
||||
Dispatch(proxy, [&](auto const& value) {
|
||||
common::AdapterDeviceSketch(value, batch_param_.max_bin, proxy->Info(), missing, p_sketch);
|
||||
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch);
|
||||
});
|
||||
}
|
||||
auto batch_rows = num_rows();
|
||||
@@ -95,8 +97,8 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
||||
if (!ref) {
|
||||
HostDeviceVector<FeatureType> ft;
|
||||
common::SketchContainer final_sketch(
|
||||
sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(),
|
||||
batch_param_.max_bin, cols, accumulated_rows, get_device());
|
||||
sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), p.max_bin, cols,
|
||||
accumulated_rows, get_device());
|
||||
for (auto const& sketch : sketch_containers) {
|
||||
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
|
||||
final_sketch.FixError();
|
||||
@@ -106,7 +108,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
||||
|
||||
final_sketch.MakeCuts(&cuts);
|
||||
} else {
|
||||
GetCutsFromRef(ref, Info().num_col_, batch_param_, &cuts);
|
||||
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
|
||||
}
|
||||
|
||||
this->info_.num_row_ = accumulated_rows;
|
||||
@@ -169,24 +171,34 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
|
||||
info_.SynchronizeNumberOfColumns();
|
||||
}
|
||||
|
||||
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {
|
||||
CheckParam(param);
|
||||
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
BatchParam const& param) {
|
||||
if (param.Initialized()) {
|
||||
CheckParam(param);
|
||||
CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin();
|
||||
}
|
||||
if (!ellpack_ && !ghist_) {
|
||||
LOG(FATAL) << "`QuantileDMatrix` not initialized.";
|
||||
}
|
||||
if (!ellpack_ && ghist_) {
|
||||
|
||||
if (!ellpack_) {
|
||||
ellpack_.reset(new EllpackPage());
|
||||
// Evaluation QuantileDMatrix initialized from CPU data might not have the correct GPU
|
||||
// ID.
|
||||
if (this->ctx_.IsCPU()) {
|
||||
this->ctx_.gpu_id = param.gpu_id;
|
||||
if (ctx->IsCUDA()) {
|
||||
this->Info().feature_types.SetDevice(ctx->gpu_id);
|
||||
*ellpack_->Impl() =
|
||||
EllpackPageImpl(ctx, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
|
||||
} else if (fmat_ctx_.IsCUDA()) {
|
||||
this->Info().feature_types.SetDevice(fmat_ctx_.gpu_id);
|
||||
*ellpack_->Impl() =
|
||||
EllpackPageImpl(&fmat_ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
|
||||
} else {
|
||||
// Can happen when QDM is initialized on CPU, but a GPU version is queried by a different QDM
|
||||
// for cut reference.
|
||||
auto cuda_ctx = ctx->MakeCUDA();
|
||||
this->Info().feature_types.SetDevice(cuda_ctx.gpu_id);
|
||||
*ellpack_->Impl() =
|
||||
EllpackPageImpl(&cuda_ctx, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
|
||||
}
|
||||
if (this->ctx_.IsCPU()) {
|
||||
this->ctx_.gpu_id = dh::CurrentDevice();
|
||||
}
|
||||
this->Info().feature_types.SetDevice(this->ctx_.gpu_id);
|
||||
*ellpack_->Impl() =
|
||||
EllpackPageImpl(&ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
|
||||
}
|
||||
CHECK(ellpack_);
|
||||
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
|
||||
@@ -196,5 +208,4 @@ BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& para
|
||||
void GetCutsFromEllpack(EllpackPage const& page, common::HistogramCuts* cuts) {
|
||||
*cuts = page.Impl()->Cuts();
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
|
||||
Reference in New Issue
Block a user