Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -11,10 +11,12 @@
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/error_msg.h" // for InconsistentMaxBin
|
||||
#include "../common/random.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "./simple_batch_iterator.h"
|
||||
#include "adapter.h"
|
||||
#include "batch_utils.h" // for CheckEmpty, RegenGHist
|
||||
#include "gradient_index.h"
|
||||
#include "xgboost/c_api.h"
|
||||
#include "xgboost/data.h"
|
||||
@@ -28,7 +30,7 @@ const MetaInfo& SimpleDMatrix::Info() const { return info_; }
|
||||
DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
|
||||
auto out = new SimpleDMatrix;
|
||||
SparsePage& out_page = *out->sparse_page_;
|
||||
for (auto const &page : this->GetBatches<SparsePage>()) {
|
||||
for (auto const& page : this->GetBatches<SparsePage>()) {
|
||||
auto batch = page.GetView();
|
||||
auto& h_data = out_page.data.HostVector();
|
||||
auto& h_offset = out_page.offset.HostVector();
|
||||
@@ -42,7 +44,7 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
|
||||
out->Info() = this->Info().Slice(ridxs);
|
||||
out->Info().num_nonzero_ = h_offset.back();
|
||||
}
|
||||
out->ctx_ = this->ctx_;
|
||||
out->fmat_ctx_ = this->fmat_ctx_;
|
||||
return out;
|
||||
}
|
||||
|
||||
@@ -52,7 +54,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
auto const slice_size = info_.num_col_ / num_slices;
|
||||
auto const slice_start = slice_size * slice_id;
|
||||
auto const slice_end = (slice_id == num_slices - 1) ? info_.num_col_ : slice_start + slice_size;
|
||||
for (auto const &page : this->GetBatches<SparsePage>()) {
|
||||
for (auto const& page : this->GetBatches<SparsePage>()) {
|
||||
auto batch = page.GetView();
|
||||
auto& h_data = out_page.data.HostVector();
|
||||
auto& h_offset = out_page.offset.HostVector();
|
||||
@@ -60,9 +62,8 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
for (bst_row_t i = 0; i < this->Info().num_row_; i++) {
|
||||
auto inst = batch[i];
|
||||
auto prev_size = h_data.size();
|
||||
std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data), [&](Entry e) {
|
||||
return e.index >= slice_start && e.index < slice_end;
|
||||
});
|
||||
std::copy_if(inst.begin(), inst.end(), std::back_inserter(h_data),
|
||||
[&](Entry e) { return e.index >= slice_start && e.index < slice_end; });
|
||||
rptr += h_data.size() - prev_size;
|
||||
h_offset.emplace_back(rptr);
|
||||
}
|
||||
@@ -73,7 +74,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
return out;
|
||||
}
|
||||
|
||||
void SimpleDMatrix::ReindexFeatures() {
|
||||
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
|
||||
if (info_.IsVerticalFederated()) {
|
||||
std::vector<uint64_t> buffer(collective::GetWorldSize());
|
||||
buffer[collective::GetRank()] = info_.num_col_;
|
||||
@@ -82,72 +83,115 @@ void SimpleDMatrix::ReindexFeatures() {
|
||||
if (offset == 0) {
|
||||
return;
|
||||
}
|
||||
sparse_page_->Reindex(offset, ctx_.Threads());
|
||||
sparse_page_->Reindex(offset, ctx->Threads());
|
||||
}
|
||||
}
|
||||
|
||||
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
||||
// since csr is the default data structure so `source_` is always available.
|
||||
auto begin_iter = BatchIterator<SparsePage>(
|
||||
new SimpleBatchIteratorImpl<SparsePage>(sparse_page_));
|
||||
auto begin_iter =
|
||||
BatchIterator<SparsePage>(new SimpleBatchIteratorImpl<SparsePage>(sparse_page_));
|
||||
return BatchSet<SparsePage>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
|
||||
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches(Context const* ctx) {
|
||||
// column page doesn't exist, generate it
|
||||
if (!column_page_) {
|
||||
column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads())));
|
||||
column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx->Threads())));
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_));
|
||||
auto begin_iter = BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_));
|
||||
return BatchSet<CSCPage>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
|
||||
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches(Context const* ctx) {
|
||||
// Sorted column page doesn't exist, generate it
|
||||
if (!sorted_column_page_) {
|
||||
sorted_column_page_.reset(
|
||||
new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx_.Threads())));
|
||||
sorted_column_page_->SortRows(ctx_.Threads());
|
||||
new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_, ctx->Threads())));
|
||||
sorted_column_page_->SortRows(ctx->Threads());
|
||||
}
|
||||
auto begin_iter = BatchIterator<SortedCSCPage>(
|
||||
new SimpleBatchIteratorImpl<SortedCSCPage>(sorted_column_page_));
|
||||
auto begin_iter =
|
||||
BatchIterator<SortedCSCPage>(new SimpleBatchIteratorImpl<SortedCSCPage>(sorted_column_page_));
|
||||
return BatchSet<SortedCSCPage>(begin_iter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void CheckEmpty(BatchParam const& l, BatchParam const& r) {
|
||||
if (l == BatchParam{}) {
|
||||
CHECK(r != BatchParam{}) << "Batch parameter is not initialized.";
|
||||
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(Context const* ctx,
|
||||
const BatchParam& param) {
|
||||
detail::CheckEmpty(batch_param_, param);
|
||||
if (ellpack_page_ && param.Initialized() && param.forbid_regen) {
|
||||
if (detail::RegenGHist(batch_param_, param)) {
|
||||
CHECK_EQ(batch_param_.max_bin, param.max_bin) << error::InconsistentMaxBin();
|
||||
}
|
||||
CHECK(!detail::RegenGHist(batch_param_, param));
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
|
||||
// ELLPACK page doesn't exist, generate it
|
||||
CheckEmpty(batch_param_, param);
|
||||
if (!ellpack_page_ || RegenGHist(batch_param_, param)) {
|
||||
CHECK_GE(param.gpu_id, 0);
|
||||
if (!ellpack_page_ || detail::RegenGHist(batch_param_, param)) {
|
||||
// ELLPACK page doesn't exist, generate it
|
||||
LOG(INFO) << "Generating new Ellpack page.";
|
||||
// These places can ask for a ellpack page:
|
||||
// - GPU hist: the ctx must be on CUDA.
|
||||
// - IterativeDMatrix::InitFromCUDA: The ctx must be on CUDA.
|
||||
// - IterativeDMatrix::InitFromCPU: It asks for ellpack only if it exists. It should
|
||||
// not regen, otherwise it indicates a mismatched parameter like max_bin.
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
ellpack_page_.reset(new EllpackPage(this, param));
|
||||
batch_param_ = param;
|
||||
if (ctx->IsCUDA()) {
|
||||
// The context passed in is on GPU, we pick it first since we prioritize the context
|
||||
// in Booster.
|
||||
ellpack_page_.reset(new EllpackPage(ctx, this, param));
|
||||
} else if (fmat_ctx_.IsCUDA()) {
|
||||
// DMatrix was initialized on GPU, we use the context from initialization.
|
||||
ellpack_page_.reset(new EllpackPage(&fmat_ctx_, this, param));
|
||||
} else {
|
||||
// Mismatched parameter, user set a new max_bin during training.
|
||||
auto cuda_ctx = ctx->MakeCUDA();
|
||||
ellpack_page_.reset(new EllpackPage(&cuda_ctx, this, param));
|
||||
}
|
||||
|
||||
batch_param_ = param.MakeCache();
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_));
|
||||
return BatchSet<EllpackPage>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& param) {
|
||||
CheckEmpty(batch_param_, param);
|
||||
if (!gradient_index_ || RegenGHist(batch_param_, param)) {
|
||||
BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(Context const* ctx,
|
||||
const BatchParam& param) {
|
||||
detail::CheckEmpty(batch_param_, param);
|
||||
// Check whether we can regenerate the gradient index. This is to keep the consistency
|
||||
// between evaluation data and training data.
|
||||
if (gradient_index_ && param.Initialized() && param.forbid_regen) {
|
||||
if (detail::RegenGHist(batch_param_, param)) {
|
||||
CHECK_EQ(batch_param_.max_bin, param.max_bin) << error::InconsistentMaxBin();
|
||||
}
|
||||
CHECK(!detail::RegenGHist(batch_param_, param)) << "Inconsistent sparse threshold.";
|
||||
}
|
||||
if (!gradient_index_ || detail::RegenGHist(batch_param_, param)) {
|
||||
// GIDX page doesn't exist, generate it
|
||||
LOG(INFO) << "Generating new Gradient Index.";
|
||||
// These places can ask for a CSR gidx:
|
||||
// - CPU Hist: the ctx must be on CPU.
|
||||
// - IterativeDMatrix::InitFromCPU: The ctx must be on CPU.
|
||||
// - IterativeDMatrix::InitFromCUDA: It asks for gidx only if it exists. It should not
|
||||
// regen, otherwise it indicates a mismatched parameter like max_bin.
|
||||
CHECK_GE(param.max_bin, 2);
|
||||
CHECK_EQ(param.gpu_id, -1);
|
||||
// Used only by approx.
|
||||
auto sorted_sketch = param.regen;
|
||||
gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin, param.sparse_thresh,
|
||||
sorted_sketch, this->ctx_.Threads(), param.hess));
|
||||
batch_param_ = param;
|
||||
if (ctx->IsCPU()) {
|
||||
// The context passed in is on CPU, we pick it first since we prioritize the context
|
||||
// in Booster.
|
||||
gradient_index_.reset(new GHistIndexMatrix{ctx, this, param.max_bin, param.sparse_thresh,
|
||||
sorted_sketch, param.hess});
|
||||
} else if (fmat_ctx_.IsCPU()) {
|
||||
// DMatrix was initialized on CPU, we use the context from initialization.
|
||||
gradient_index_.reset(new GHistIndexMatrix{&fmat_ctx_, this, param.max_bin,
|
||||
param.sparse_thresh, sorted_sketch, param.hess});
|
||||
} else {
|
||||
// Mismatched parameter, user set a new max_bin during training.
|
||||
auto cpu_ctx = ctx->MakeCPU();
|
||||
gradient_index_.reset(new GHistIndexMatrix{&cpu_ctx, this, param.max_bin, param.sparse_thresh,
|
||||
sorted_sketch, param.hess});
|
||||
}
|
||||
|
||||
batch_param_ = param.MakeCache();
|
||||
CHECK_EQ(batch_param_.hess.data(), param.hess.data());
|
||||
}
|
||||
auto begin_iter = BatchIterator<GHistIndexMatrix>(
|
||||
@@ -155,7 +199,7 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(BatchParam const&) {
|
||||
BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(Context const*, BatchParam const&) {
|
||||
auto casted = std::make_shared<ExtSparsePage>(sparse_page_);
|
||||
CHECK(casted);
|
||||
auto begin_iter =
|
||||
@@ -166,7 +210,8 @@ BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(BatchParam const&) {
|
||||
template <typename AdapterT>
|
||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
|
||||
DataSplitMode data_split_mode) {
|
||||
this->ctx_.nthread = nthread;
|
||||
Context ctx;
|
||||
ctx.Init(Args{{"nthread", std::to_string(nthread)}});
|
||||
|
||||
std::vector<uint64_t> qids;
|
||||
uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||
@@ -176,13 +221,13 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
|
||||
auto& data_vec = sparse_page_->data.HostVector();
|
||||
uint64_t inferred_num_columns = 0;
|
||||
uint64_t total_batch_size = 0;
|
||||
// batch_size is either number of rows or cols, depending on data layout
|
||||
// batch_size is either number of rows or cols, depending on data layout
|
||||
|
||||
adapter->BeforeFirst();
|
||||
// Iterate over batches of input data
|
||||
while (adapter->Next()) {
|
||||
auto& batch = adapter->Value();
|
||||
auto batch_max_columns = sparse_page_->Push(batch, missing, ctx_.Threads());
|
||||
auto batch_max_columns = sparse_page_->Push(batch, missing, ctx.Threads());
|
||||
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
|
||||
total_batch_size += batch.Size();
|
||||
// Append meta information if available
|
||||
@@ -229,19 +274,18 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
|
||||
info_.num_col_ = adapter->NumColumns();
|
||||
}
|
||||
|
||||
|
||||
// Synchronise worker columns
|
||||
info_.data_split_mode = data_split_mode;
|
||||
ReindexFeatures();
|
||||
ReindexFeatures(&ctx);
|
||||
info_.SynchronizeNumberOfColumns();
|
||||
|
||||
if (adapter->NumRows() == kAdapterUnknownSize) {
|
||||
using IteratorAdapterT
|
||||
= IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>;
|
||||
using IteratorAdapterT =
|
||||
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>;
|
||||
// If AdapterT is either IteratorAdapter or FileAdapter type, use the total batch size to
|
||||
// determine the correct number of rows, as offset_vec may be too short
|
||||
if (std::is_same<AdapterT, IteratorAdapterT>::value
|
||||
|| std::is_same<AdapterT, FileAdapter>::value) {
|
||||
if (std::is_same<AdapterT, IteratorAdapterT>::value ||
|
||||
std::is_same<AdapterT, FileAdapter>::value) {
|
||||
info_.num_row_ = total_batch_size;
|
||||
// Ensure offset_vec.size() - 1 == [number of rows]
|
||||
while (offset_vec.size() - 1 < total_batch_size) {
|
||||
@@ -265,9 +309,11 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
|
||||
info_.num_nonzero_ = data_vec.size();
|
||||
|
||||
// Sort the index for row partitioners used by variuos tree methods.
|
||||
if (!sparse_page_->IsIndicesSorted(this->ctx_.Threads())) {
|
||||
sparse_page_->SortIndices(this->ctx_.Threads());
|
||||
if (!sparse_page_->IsIndicesSorted(ctx.Threads())) {
|
||||
sparse_page_->SortIndices(ctx.Threads());
|
||||
}
|
||||
|
||||
this->fmat_ctx_ = ctx;
|
||||
}
|
||||
|
||||
SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
|
||||
@@ -280,12 +326,12 @@ SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
|
||||
}
|
||||
|
||||
void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
||||
int tmagic = kMagic;
|
||||
fo->Write(tmagic);
|
||||
info_.SaveBinary(fo.get());
|
||||
fo->Write(sparse_page_->offset.HostVector());
|
||||
fo->Write(sparse_page_->data.HostVector());
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
|
||||
int tmagic = kMagic;
|
||||
fo->Write(tmagic);
|
||||
info_.SaveBinary(fo.get());
|
||||
fo->Write(sparse_page_->offset.HostVector());
|
||||
fo->Write(sparse_page_->data.HostVector());
|
||||
}
|
||||
|
||||
template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, int nthread,
|
||||
@@ -305,14 +351,14 @@ template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing,
|
||||
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread,
|
||||
DataSplitMode data_split_mode);
|
||||
template SimpleDMatrix::SimpleDMatrix(
|
||||
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>
|
||||
*adapter,
|
||||
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
|
||||
float missing, int nthread, DataSplitMode data_split_mode);
|
||||
|
||||
template <>
|
||||
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread,
|
||||
DataSplitMode data_split_mode) {
|
||||
ctx_.nthread = nthread;
|
||||
Context ctx;
|
||||
ctx.nthread = nthread;
|
||||
|
||||
auto& offset_vec = sparse_page_->offset.HostVector();
|
||||
auto& data_vec = sparse_page_->data.HostVector();
|
||||
@@ -326,7 +372,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
|
||||
size_t num_elements = 0;
|
||||
size_t num_rows = 0;
|
||||
// Import Arrow RecordBatches
|
||||
#pragma omp parallel for reduction(+ : num_elements, num_rows) num_threads(ctx_.Threads())
|
||||
#pragma omp parallel for reduction(+ : num_elements, num_rows) num_threads(ctx.Threads())
|
||||
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
|
||||
num_elements += batches[i]->Import(missing);
|
||||
num_rows += batches[i]->Size();
|
||||
@@ -348,7 +394,7 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
|
||||
data_vec.resize(total_elements);
|
||||
offset_vec.resize(total_batch_size + 1);
|
||||
// Copy data into DMatrix
|
||||
#pragma omp parallel num_threads(ctx_.Threads())
|
||||
#pragma omp parallel num_threads(ctx.Threads())
|
||||
{
|
||||
#pragma omp for nowait
|
||||
for (int i = 0; i < static_cast<int>(batches.size()); ++i) { // NOLINT
|
||||
@@ -372,12 +418,14 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
|
||||
// Synchronise worker columns
|
||||
info_.num_col_ = adapter->NumColumns();
|
||||
info_.data_split_mode = data_split_mode;
|
||||
ReindexFeatures();
|
||||
ReindexFeatures(&ctx);
|
||||
info_.SynchronizeNumberOfColumns();
|
||||
|
||||
info_.num_row_ = total_batch_size;
|
||||
info_.num_nonzero_ = data_vec.size();
|
||||
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);
|
||||
|
||||
fmat_ctx_ = ctx;
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user