Use Booster context in DMatrix. (#8896)

- Pass context from booster to DMatrix.
- Use context instead of integer for `n_threads`.
- Check the consistency configuration for `max_bin`.
- Test for all combinations of initialization options.
This commit is contained in:
Jiaming Yuan
2023-04-28 21:47:14 +08:00
committed by GitHub
parent 1f9a57d17b
commit 08ce495b5d
67 changed files with 1283 additions and 935 deletions

View File

@@ -1,25 +1,26 @@
/*!
* Copyright 2022 XGBoost contributors
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include "iterative_dmatrix.h"
#include <algorithm> // std::copy
#include <cstddef> // std::size_t
#include <type_traits> // std::underlying_type_t
#include <vector> // std::vector
#include <algorithm> // for copy
#include <cstddef> // for size_t
#include <memory> // for shared_ptr
#include <type_traits> // for underlying_type_t
#include <vector> // for vector
#include "../collective/communicator-inl.h"
#include "../common/categorical.h" // common::IsCat
#include "../common/column_matrix.h"
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
#include "batch_utils.h" // for RegenGHist
#include "gradient_index.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
#include "xgboost/data.h" // FeatureType
#include "xgboost/data.h" // for FeatureType, DMatrix
#include "xgboost/logging.h"
namespace xgboost {
namespace data {
namespace xgboost::data {
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing, int nthread,
@@ -34,60 +35,61 @@ IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle pro
auto d = MakeProxy(proxy_)->DeviceIdx();
StringView msg{"All batch should be on the same device."};
if (batch_param_.gpu_id != Context::kCpuId) {
CHECK_EQ(d, batch_param_.gpu_id) << msg;
}
batch_param_ = BatchParam{d, max_bin};
Context ctx;
ctx.UpdateAllowUnknown(Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
// hardcoded parameter.
batch_param_.sparse_thresh = tree::TrainParam::DftSparseThreshold();
BatchParam p{max_bin, tree::TrainParam::DftSparseThreshold()};
ctx_.UpdateAllowUnknown(
Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
if (ctx_.IsCPU()) {
this->InitFromCPU(iter_handle, missing, ref);
if (ctx.IsCPU()) {
this->InitFromCPU(&ctx, p, iter_handle, missing, ref);
} else {
this->InitFromCUDA(iter_handle, missing, ref);
this->InitFromCUDA(&ctx, p, iter_handle, missing, ref);
}
this->fmat_ctx_ = ctx;
this->batch_ = p;
}
void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
common::HistogramCuts* p_cuts) {
CHECK(ref_);
void GetCutsFromRef(Context const* ctx, std::shared_ptr<DMatrix> ref, bst_feature_t n_features,
BatchParam p, common::HistogramCuts* p_cuts) {
CHECK(ref);
CHECK(p_cuts);
auto csr = [&]() {
for (auto const& page : ref_->GetBatches<GHistIndexMatrix>(p)) {
p.forbid_regen = true;
// Fetch cuts from GIDX
auto csr = [&] {
for (auto const& page : ref->GetBatches<GHistIndexMatrix>(ctx, p)) {
*p_cuts = page.cut;
break;
}
};
auto ellpack = [&]() {
// workaround ellpack being initialized from CPU.
if (p.gpu_id == Context::kCpuId) {
p.gpu_id = ref_->Ctx()->gpu_id;
}
if (p.gpu_id == Context::kCpuId) {
p.gpu_id = 0;
}
for (auto const& page : ref_->GetBatches<EllpackPage>(p)) {
// Fetch cuts from Ellpack.
auto ellpack = [&] {
for (auto const& page : ref->GetBatches<EllpackPage>(ctx, p)) {
GetCutsFromEllpack(page, p_cuts);
break;
}
};
if (ref_->PageExists<GHistIndexMatrix>()) {
if (ref->PageExists<GHistIndexMatrix>() && ref->PageExists<EllpackPage>()) {
// Both exists
if (ctx->IsCPU()) {
csr();
} else {
ellpack();
}
} else if (ref->PageExists<GHistIndexMatrix>()) {
csr();
} else if (ref_->PageExists<EllpackPage>()) {
} else if (ref->PageExists<EllpackPage>()) {
ellpack();
} else {
if (p.gpu_id == Context::kCpuId) {
// None exist
if (ctx->IsCPU()) {
csr();
} else {
ellpack();
}
}
CHECK_EQ(ref_->Info().num_col_, n_features)
CHECK_EQ(ref->Info().num_col_, n_features)
<< "Invalid ref DMatrix, different number of features.";
}
@@ -112,7 +114,8 @@ void SyncFeatureType(std::vector<FeatureType>* p_h_ft) {
}
} // anonymous namespace
void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p,
DataIterHandle iter_handle, float missing,
std::shared_ptr<DMatrix> ref) {
DMatrixProxy* proxy = MakeProxy(proxy_);
CHECK(proxy);
@@ -133,7 +136,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
auto const is_valid = data::IsValidFunctor{missing};
auto nnz_cnt = [&]() {
return HostAdapterDispatch(proxy, [&](auto const& value) {
size_t n_threads = ctx_.Threads();
size_t n_threads = ctx->Threads();
size_t n_features = column_sizes.size();
linalg::Tensor<std::size_t, 2> column_sizes_tloc({n_threads, n_features}, Context::kCpuId);
column_sizes_tloc.Data()->Fill(0ul);
@@ -158,10 +161,10 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
});
};
size_t n_features = 0;
size_t n_batches = 0;
size_t accumulated_rows{0};
size_t nnz{0};
std::uint64_t n_features = 0;
std::size_t n_batches = 0;
std::uint64_t accumulated_rows{0};
std::uint64_t nnz{0};
/**
* CPU impl needs an additional loop for accumulating the column size.
@@ -203,7 +206,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
accumulated_rows = 0;
std::vector<FeatureType> h_ft;
if (ref) {
GetCutsFromRef(ref, Info().num_col_, batch_param_, &cuts);
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
h_ft = ref->Info().feature_types.HostVector();
} else {
size_t i = 0;
@@ -211,9 +214,8 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
if (!p_sketch) {
h_ft = proxy->Info().feature_types.ConstHostVector();
SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
ctx_.Threads()});
p_sketch.reset(new common::HostSketchContainer{ctx, p.max_bin, h_ft, column_sizes,
!proxy->Info().group_ptr_.empty()});
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
@@ -237,15 +239,15 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
/**
* Generate gradient index.
*/
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), batch_param_.max_bin);
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), p.max_bin);
size_t rbegin = 0;
size_t prev_sum = 0;
size_t i = 0;
while (iter.Next()) {
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
this->ghist_->PushAdapterBatch(&ctx_, rbegin, prev_sum, batch, missing, h_ft,
batch_param_.sparse_thresh, Info().num_row_);
this->ghist_->PushAdapterBatch(ctx, rbegin, prev_sum, batch, missing, h_ft, p.sparse_thresh,
Info().num_row_);
});
if (n_batches != 1) {
this->info_.Extend(std::move(proxy->Info()), false, true);
@@ -265,7 +267,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
accumulated_rows = 0;
while (iter.Next()) {
HostAdapterDispatch(proxy, [&](auto const& batch) {
this->ghist_->PushAdapterBatchColumns(&ctx_, batch, missing, accumulated_rows);
this->ghist_->PushAdapterBatchColumns(ctx, batch, missing, accumulated_rows);
});
accumulated_rows += num_rows();
}
@@ -282,11 +284,27 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
Info().feature_types.HostVector() = h_ft;
}
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {
CheckParam(param);
BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(Context const* ctx,
BatchParam const& param) {
if (param.Initialized()) {
CheckParam(param);
CHECK(!detail::RegenGHist(param, batch_)) << error::InconsistentMaxBin();
}
if (!ellpack_ && !ghist_) {
LOG(FATAL) << "`QuantileDMatrix` not initialized.";
}
if (!ghist_) {
CHECK(ellpack_);
ghist_ = std::make_shared<GHistIndexMatrix>(&ctx_, Info(), *ellpack_, param);
if (ctx->IsCPU()) {
ghist_ = std::make_shared<GHistIndexMatrix>(ctx, Info(), *ellpack_, param);
} else if (fmat_ctx_.IsCPU()) {
ghist_ = std::make_shared<GHistIndexMatrix>(&fmat_ctx_, Info(), *ellpack_, param);
} else {
// Can happen when QDM is initialized on GPU, but a CPU version is queried by a different QDM
// for cut reference.
auto cpu_ctx = ctx->MakeCPU();
ghist_ = std::make_shared<GHistIndexMatrix>(&cpu_ctx, Info(), *ellpack_, param);
}
}
if (!std::isnan(param.sparse_thresh) &&
@@ -300,8 +318,9 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const&
return BatchSet<GHistIndexMatrix>(begin_iter);
}
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param) {
for (auto const& page : this->GetGradientIndex(param)) {
BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(Context const* ctx,
BatchParam const& param) {
for (auto const& page : this->GetGradientIndex(ctx, param)) {
auto p_out = std::make_shared<SparsePage>();
p_out->data.Resize(this->Info().num_nonzero_);
p_out->offset.Resize(this->Info().num_row_ + 1);
@@ -336,5 +355,26 @@ BatchSet<ExtSparsePage> IterativeDMatrix::GetExtBatches(BatchParam const& param)
BatchIterator<ExtSparsePage>(new SimpleBatchIteratorImpl<ExtSparsePage>(nullptr));
return BatchSet<ExtSparsePage>(begin_iter);
}
} // namespace data
} // namespace xgboost
#if !defined(XGBOOST_USE_CUDA)
inline void IterativeDMatrix::InitFromCUDA(Context const*, BatchParam const&, DataIterHandle, float,
std::shared_ptr<DMatrix>) {
// silent the warning about unused variables.
(void)(proxy_);
(void)(reset_);
(void)(next_);
common::AssertGPUSupport();
}
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(Context const* ctx,
BatchParam const& param) {
common::AssertGPUSupport();
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
}
inline void GetCutsFromEllpack(EllpackPage const&, common::HistogramCuts*) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace xgboost::data