Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -1,25 +1,26 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
*/
|
||||
#include "iterative_dmatrix.h"
|
||||
|
||||
#include <algorithm> // std::copy
|
||||
#include <cstddef> // std::size_t
|
||||
#include <type_traits> // std::underlying_type_t
|
||||
#include <vector> // std::vector
|
||||
#include <algorithm> // for copy
|
||||
#include <cstddef> // for size_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <type_traits> // for underlying_type_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/categorical.h" // common::IsCat
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "batch_utils.h" // for RegenGHist
|
||||
#include "gradient_index.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
#include "xgboost/data.h" // FeatureType
|
||||
#include "xgboost/data.h" // for FeatureType, DMatrix
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
namespace xgboost::data {
|
||||
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
||||
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
|
||||
XGDMatrixCallbackNext* next, float missing, int nthread,
|
||||
@@ -34,60 +35,61 @@ IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle pro
|
||||
|
||||
auto d = MakeProxy(proxy_)->DeviceIdx();
|
||||
|
||||
StringView msg{"All batch should be on the same device."};
|
||||
if (batch_param_.gpu_id != Context::kCpuId) {
|
||||
CHECK_EQ(d, batch_param_.gpu_id) << msg;
|
||||
}
|
||||
|
||||
batch_param_ = BatchParam{d, max_bin};
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
|
||||
// hardcoded parameter.
|
||||
batch_param_.sparse_thresh = tree::TrainParam::DftSparseThreshold();
|
||||
BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()};
|
||||
|
||||
ctx_.UpdateAllowUnknown(
|
||||
Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
|
||||
if (ctx_.IsCPU()) {
|
||||
this->InitFromCPU(iter_handle, missing, ref);
|
||||
if (ctx.IsCPU()) {
|
||||
this->InitFromCPU(&ctx, p, iter_handle, missing, ref);
|
||||
} else {
|
||||
this->InitFromCUDA(iter_handle, missing, ref);
|
||||
this->InitFromCUDA(&ctx, p, iter_handle, missing, ref);
|
||||
}
|
||||
|
||||
this->fmat_ctx_ = ctx;
|
||||
this->batch_ = p;
|
||||
}
|
||||
|
||||
void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
|
||||
common::HistogramCuts* p_cuts) {
|
||||
CHECK(ref_);
|
||||
void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_feature_t n_features,
|
||||
BatchParam p, common::HistogramCuts* p_cuts) {
|
||||
CHECK(ref);
|
||||
CHECK(p_cuts);
|
||||
auto csr = [&]() {
|
||||
for (auto const& page : ref_->GetBatches<GHistIndexMatrix>(p)) {
|
||||
p.forbid_regen = true;
|
||||
// Fetch cuts from GIDX
|
||||
auto csr = [&] {
|
||||
for (auto const& page : ref->GetBatches<GHistIndexMatrix>(ctx, p)) {
|
||||
*p_cuts = page.cut;
|
||||
break;
|
||||
}
|
||||
};
|
||||
auto ellpack = [&]() {
|
||||
// workaround ellpack being initialized from CPU.
|
||||
if (p.gpu_id == Context::kCpuId) {
|
||||
p.gpu_id = ref_->Ctx()->gpu_id;
|
||||
}
|
||||
if (p.gpu_id == Context::kCpuId) {
|
||||
p.gpu_id = 0;
|
||||
}
|
||||
for (auto const& page : ref_->GetBatches<EllpackPage>(p)) {
|
||||
// Fetch cuts from Ellpack.
|
||||
auto ellpack = [&] {
|
||||
for (auto const& page : ref->GetBatches<EllpackPage>(ctx, p)) {
|
||||
GetCutsFromEllpack(page, p_cuts);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if (ref_->PageExists<GHistIndexMatrix>()) {
|
||||
if (ref->PageExists<GHistIndexMatrix>() && ref->PageExists<EllpackPage>()) {
|
||||
// Both exists
|
||||
if (ctx->IsCPU()) {
|
||||
csr();
|
||||
} else {
|
||||
ellpack();
|
||||
}
|
||||
} else if (ref->PageExists<GHistIndexMatrix>()) {
|
||||
csr();
|
||||
} else if (ref_->PageExists<EllpackPage>()) {
|
||||
} else if (ref->PageExists<EllpackPage>()) {
|
||||
ellpack();
|
||||
} else {
|
||||
if (p.gpu_id == Context::kCpuId) {
|
||||
// None exist
|
||||
if (ctx->IsCPU()) {
|
||||
csr();
|
||||
} else {
|
||||
ellpack();
|
||||
}
|
||||
}
|
||||
CHECK_EQ(ref_->Info().num_col_, n_features)
|
||||
CHECK_EQ(ref->Info().num_col_, n_features)
|
||||
<< "Invalid ref DMatrix, different number of features.";
|
||||
}
|
||||
|
||||
@@ -112,7 +114,8 @@ void SyncFeatureType(std::vector<FeatureType>* p_h_ft) {
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
|
||||
DataIterHandle iter_handle, float missing,
|
||||
std::shared_ptr<DMatrix> ref) {
|
||||
DMatrixProxy* proxy = MakeProxy(proxy_);
|
||||
CHECK(proxy);
|
||||
@@ -133,7 +136,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
auto const is_valid = data::IsValidFunctor{missing};
|
||||
auto nnz_cnt = [&]() {
|
||||
return HostAdapterDispatch(proxy, [&](auto const& value) {
|
||||
size_t n_threads = ctx_.Threads();
|
||||
size_t n_threads = ctx->Threads();
|
||||
size_t n_features = column_sizes.size();
|
||||
linalg::Tensor<std::size_t, 2> column_sizes_tloc({n_threads, n_features}, Context::kCpuId);
|
||||
column_sizes_tloc.Data()->Fill(0ul);
|
||||
@@ -158,10 +161,10 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
});
|
||||
};
|
||||
|
||||
size_t n_features = 0;
|
||||
size_t n_batches = 0;
|
||||
size_t accumulated_rows{0};
|
||||
size_t nnz{0};
|
||||
std::uint64_t n_features = 0;
|
||||
std::size_t n_batches = 0;
|
||||
std::uint64_t accumulated_rows{0};
|
||||
std::uint64_t nnz{0};
|
||||
|
||||
/**
|
||||
* CPU impl needs an additional loop for accumulating the column size.
|
||||
@@ -203,7 +206,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
accumulated_rows = 0;
|
||||
std::vector<FeatureType> h_ft;
|
||||
if (ref) {
|
||||
GetCutsFromRef(ref, Info().num_col_, batch_param_, &cuts);
|
||||
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
|
||||
h_ft = ref->Info().feature_types.HostVector();
|
||||
} else {
|
||||
size_t i = 0;
|
||||
@@ -211,9 +214,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
if (!p_sketch) {
|
||||
h_ft = proxy->Info().feature_types.ConstHostVector();
|
||||
SyncFeatureType(&h_ft);
|
||||
p_sketch.reset(new common::HostSketchContainer{
|
||||
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
|
||||
ctx_.Threads()});
|
||||
p_sketch.reset(new common::HostSketchContainer{ctx, p.max_bin, h_ft, column_sizes,
|
||||
!proxy->Info().group_ptr_.empty()});
|
||||
}
|
||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||
proxy->Info().num_nonzero_ = batch_nnz[i];
|
||||
@@ -237,15 +239,15 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
/**
|
||||
* Generate gradient index.
|
||||
*/
|
||||
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), batch_param_.max_bin);
|
||||
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), p.max_bin);
|
||||
size_t rbegin = 0;
|
||||
size_t prev_sum = 0;
|
||||
size_t i = 0;
|
||||
while (iter.Next()) {
|
||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||
proxy->Info().num_nonzero_ = batch_nnz[i];
|
||||
this->ghist_->PushAdapterBatch(&ctx_, rbegin, prev_sum, batch, missing, h_ft,
|
||||
batch_param_.sparse_thresh, Info().num_row_);
|
||||
this->ghist_->PushAdapterBatch(ctx, rbegin, prev_sum, batch, missing, h_ft, p.sparse_thresh,
|
||||
Info().num_row_);
|
||||
});
|
||||
if (n_batches != 1) {
|
||||
this->info_.Extend(std::move(proxy->Info()), false, true);
|
||||
@@ -265,7 +267,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
accumulated_rows = 0;
|
||||
while (iter.Next()) {
|
||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||
this->ghist_->PushAdapterBatchColumns(&ctx_, batch, missing, accumulated_rows);
|
||||
this->ghist_->PushAdapterBatchColumns(ctx, batch, missing, accumulated_rows);
|
||||
});
|
||||
accumulated_rows += num_rows();
|
||||
}
|
||||
@@ -282,11 +284,27 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
Info().feature_types.HostVector() = h_ft;
|
||||
}
|
||||
|
||||
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {
|
||||
CheckParam(param);
|
||||
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(Context const* ctx,
|
||||
BatchParam const& param) {
|
||||
if (param.Initialized()) {
|
||||
CheckParam(param);
|
||||
CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin();
|
||||
}
|
||||
if (!ellpack_ && !ghist_) {
|
||||
LOG(FATAL) << "`QuantileDMatrix` not initialized.";
|
||||
}
|
||||
|
||||
if (!ghist_) {
|
||||
CHECK(ellpack_);
|
||||
ghist_ = std::make_shared<GHistIndexMatrix>(&ctx_, Info(), *ellpack_, param);
|
||||
if (ctx->IsCPU()) {
|
||||
ghist_ = std::make_shared<GHistIndexMatrix>(ctx, Info(), *ellpack_, param);
|
||||
} else if (fmat_ctx_.IsCPU()) {
|
||||
ghist_ = std::make_shared<GHistIndexMatrix>(&fmat_ctx_, Info(), *ellpack_, param);
|
||||
} else {
|
||||
// Can happen when QDM is initialized on GPU, but a CPU version is queried by a different QDM
|
||||
// for cut reference.
|
||||
auto cpu_ctx = ctx->MakeCPU();
|
||||
ghist_ = std::make_shared<GHistIndexMatrix>(&cpu_ctx, Info(), *ellpack_, param);
|
||||
}
|
||||
}
|
||||
|
||||
if (!std::isnan(param.sparse_thresh) &&
|
||||
@@ -300,8 +318,9 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const&
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param) {
|
||||
for (auto const& page : this->GetGradientIndex(param)) {
|
||||
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(Context const* ctx,
|
||||
BatchParam const& param) {
|
||||
for (auto const& page : this->GetGradientIndex(ctx, param)) {
|
||||
auto p_out = std::make_shared<SparsePage>();
|
||||
p_out->data.Resize(this->Info().num_nonzero_);
|
||||
p_out->offset.Resize(this->Info().num_row_ + 1);
|
||||
@@ -336,5 +355,26 @@ BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param)
|
||||
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
|
||||
return BatchSet<ExtSparsePage>(begin_iter);
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline void IterativeDMatrix::InitFromCUDA(Context const*, BatchParam const&, DataIterHandle, float,
|
||||
std::shared_ptr<DMatrix>) {
|
||||
// silent the warning about unused variables.
|
||||
(void)(proxy_);
|
||||
(void)(reset_);
|
||||
(void)(next_);
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
|
||||
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
BatchParam const& param) {
|
||||
common::AssertGPUSupport();
|
||||
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
|
||||
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
|
||||
}
|
||||
|
||||
inline void GetCutsFromEllpack(EllpackPage const&, common::HistogramCuts*) {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace xgboost::data
|
||||
|
||||
Reference in New Issue
Block a user