Improve test coverage with predictor configuration. (#9354)

* Improve test coverage with predictor configuration.

- Test with ext memory.
- Test with QDM.
- Test with dart.
This commit is contained in:
Jiaming Yuan
2023-07-05 15:17:22 +08:00
committed by GitHub
parent 6c9c8a9001
commit 645037e376
17 changed files with 280 additions and 79 deletions

View File

@@ -47,5 +47,9 @@ inline void MaxFeatureSize(std::uint64_t n_features) {
<< "Unfortunately, XGBoost does not support data matrices with "
<< std::numeric_limits<bst_feature_t>::max() << " features or greater";
}
constexpr StringView InplacePredictProxy() {
return "Inplace predict accepts only DMatrixProxy as input.";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_

View File

@@ -68,6 +68,7 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
}
std::size_t Write(GHistIndexMatrix const& page, common::AlignedFileWriteStream* fo) override {
CHECK_NE(page.index.Size(), 0) << "Empty page is not supported.";
std::size_t bytes = 0;
bytes += WriteHistogramCuts(page.cut, fo);
// indptr

View File

@@ -1,10 +1,9 @@
/*!
* Copyright 2021-2022 by XGBoost Contributors
/**
* Copyright 2021-2023, XGBoost Contributors
*/
#include "gradient_index_page_source.h"
namespace xgboost {
namespace data {
namespace xgboost::data {
void GradientIndexPageSource::Fetch() {
if (!this->ReadCache()) {
if (count_ != 0 && !sync_) {
@@ -21,5 +20,4 @@ void GradientIndexPageSource::Fetch() {
this->WriteCache();
}
}
} // namespace data
} // namespace xgboost
} // namespace xgboost::data

View File

@@ -18,7 +18,7 @@
#include <vector>
#include "../common/common.h"
#include "../common/error_msg.h" // for UnknownDevice
#include "../common/error_msg.h" // for UnknownDevice, InplacePredictProxy
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "../common/timer.h"
@@ -542,6 +542,18 @@ void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds,
}
}
namespace {
inline void MismatchedDevices(Context const* booster, Context const* data) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << booster->DeviceName()
<< ", while the input data is on: " << data->DeviceName() << ".\n"
<< R"(Potential solutions:
- Use a data structure that matches the device ordinal in the booster.
- Set the device for booster before call to inplace_predict.
)";
}
}; // namespace
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
bst_layer_t layer_begin, bst_layer_t layer_end) {
// dispatch to const function.
@@ -555,24 +567,26 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
if (p_m->Ctx()->Device() != this->ctx_->Device()) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << this->ctx_->DeviceName()
<< ", while the input data is on: " << p_m->Ctx()->DeviceName() << ".";
MismatchedDevices(this->ctx_, p_m->Ctx());
CHECK_EQ(out_preds->version, 0);
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
auto any_adapter = proxy->Adapter();
CHECK(proxy) << error::InplacePredictProxy();
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
this->PredictBatchImpl(p_fmat.get(), out_preds, false, layer_begin, layer_end);
return;
}
if (this->ctx_->IsCPU()) {
this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
} else if (p_m->Ctx()->IsCUDA()) {
CHECK(this->gpu_predictor_);
this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
} else {
LOG(FATAL) << error::UnknownDevice();
bool known_type = this->ctx_->DispatchDevice(
[&, begin = tree_begin, end = tree_end] {
return this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
},
[&, begin = tree_begin, end = tree_end] {
return this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
});
if (!known_type) {
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
CHECK(proxy) << error::InplacePredictProxy();
LOG(FATAL) << "Unknown data type for inplace prediction:" << proxy->Adapter().type().name();
}
}
@@ -808,11 +822,9 @@ class Dart : public GBTree {
auto n_groups = model_.learner_model_param->num_output_group;
if (ctx_->Device() != p_fmat->Ctx()->Device()) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << this->ctx_->DeviceName()
<< ", while the input data is on: " << p_fmat->Ctx()->DeviceName() << ".";
MismatchedDevices(ctx_, p_fmat->Ctx());
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
auto any_adapter = proxy->Adapter();
CHECK(proxy) << error::InplacePredictProxy();
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
this->PredictBatchImpl(p_fmat.get(), p_out_preds, false, layer_begin, layer_end);
return;
@@ -825,20 +837,15 @@ class Dart : public GBTree {
}
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
auto get_predictor = [&]() -> Predictor const* {
if (ctx_->IsCPU()) {
return cpu_predictor_.get();
} else if (ctx_->IsCUDA()) {
CHECK(this->gpu_predictor_);
return gpu_predictor_.get();
} else {
LOG(FATAL) << error::UnknownDevice();
return nullptr;
}
};
auto predict_impl = [&](size_t i) {
predts.predictions.Fill(0);
bool success{get_predictor()->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)};
bool success = this->ctx_->DispatchDevice(
[&] {
return cpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
},
[&] {
return gpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
});
CHECK(success) << msg;
};
@@ -846,7 +853,15 @@ class Dart : public GBTree {
for (bst_tree_t i = tree_begin; i < tree_end; ++i) {
predict_impl(i);
if (i == tree_begin) {
get_predictor()->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
this->ctx_->DispatchDevice(
[&] {
this->cpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
model_);
},
[&] {
this->gpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
model_);
});
}
// Multiple the tree weight
auto w = this->weight_drop_.at(i);

View File

@@ -16,6 +16,7 @@
#include "../common/bitfield.h" // for RBitField8
#include "../common/categorical.h" // for IsCat, Decision
#include "../common/common.h" // for DivRoundUp
#include "../common/error_msg.h" // for InplacePredictProxy
#include "../common/math.h" // for CheckNAN
#include "../common/threading_utils.h" // for ParallelFor
#include "../data/adapter.h" // for ArrayAdapter, CSRAdapter, CSRArrayAdapter
@@ -741,7 +742,7 @@ class CPUPredictor : public Predictor {
PredictionCacheEntry *out_preds, uint32_t tree_begin,
unsigned tree_end) const override {
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
CHECK(proxy)<< error::InplacePredictProxy();
CHECK(!p_m->Info().IsColumnSplit())
<< "Inplace predict support for column-wise data split is not yet implemented.";
auto x = proxy->Adapter();

View File

@@ -15,8 +15,9 @@
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/common.h"
#include "../common/cuda_context.cuh"
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/device_helpers.cuh"
#include "../common/error_msg.h" // for InplacePredictProxy
#include "../data/device_adapter.cuh"
#include "../data/ellpack_page.cuh"
#include "../data/proxy_dmatrix.h"
@@ -989,7 +990,7 @@ class GPUPredictor : public xgboost::Predictor {
PredictionCacheEntry* out_preds, uint32_t tree_begin,
unsigned tree_end) const override {
auto proxy = dynamic_cast<data::DMatrixProxy*>(p_m.get());
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
CHECK(proxy) << error::InplacePredictProxy();
auto x = proxy->Adapter();
if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
this->DispatchedInplacePredict<data::CupyAdapter,