Improve test coverage with predictor configuration. (#9354)

* Improve test coverage with predictor configuration.

- Test with ext memory.
- Test with QDM.
- Test with dart.
This commit is contained in:
Jiaming Yuan
2023-07-05 15:17:22 +08:00
committed by GitHub
parent 6c9c8a9001
commit 645037e376
17 changed files with 280 additions and 79 deletions

View File

@@ -18,7 +18,7 @@
#include <vector>
#include "../common/common.h"
#include "../common/error_msg.h" // for UnknownDevice
#include "../common/error_msg.h" // for UnknownDevice, InplacePredictProxy
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "../common/timer.h"
@@ -542,6 +542,18 @@ void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds,
}
}
namespace {
inline void MismatchedDevices(Context const* booster, Context const* data) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << booster->DeviceName()
<< ", while the input data is on: " << data->DeviceName() << ".\n"
<< R"(Potential solutions:
- Use a data structure that matches the device ordinal in the booster.
- Set the device for booster before call to inplace_predict.
)";
}
}; // namespace
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
bst_layer_t layer_begin, bst_layer_t layer_end) {
// dispatch to const function.
@@ -555,24 +567,26 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
if (p_m->Ctx()->Device() != this->ctx_->Device()) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << this->ctx_->DeviceName()
<< ", while the input data is on: " << p_m->Ctx()->DeviceName() << ".";
MismatchedDevices(this->ctx_, p_m->Ctx());
CHECK_EQ(out_preds->version, 0);
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
auto any_adapter = proxy->Adapter();
CHECK(proxy) << error::InplacePredictProxy();
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
this->PredictBatchImpl(p_fmat.get(), out_preds, false, layer_begin, layer_end);
return;
}
if (this->ctx_->IsCPU()) {
this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
} else if (p_m->Ctx()->IsCUDA()) {
CHECK(this->gpu_predictor_);
this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
} else {
LOG(FATAL) << error::UnknownDevice();
bool known_type = this->ctx_->DispatchDevice(
[&, begin = tree_begin, end = tree_end] {
return this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
},
[&, begin = tree_begin, end = tree_end] {
return this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, begin, end);
});
if (!known_type) {
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
CHECK(proxy) << error::InplacePredictProxy();
LOG(FATAL) << "Unknown data type for inplace prediction:" << proxy->Adapter().type().name();
}
}
@@ -808,11 +822,9 @@ class Dart : public GBTree {
auto n_groups = model_.learner_model_param->num_output_group;
if (ctx_->Device() != p_fmat->Ctx()->Device()) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << this->ctx_->DeviceName()
<< ", while the input data is on: " << p_fmat->Ctx()->DeviceName() << ".";
MismatchedDevices(ctx_, p_fmat->Ctx());
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
auto any_adapter = proxy->Adapter();
CHECK(proxy) << error::InplacePredictProxy();
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
this->PredictBatchImpl(p_fmat.get(), p_out_preds, false, layer_begin, layer_end);
return;
@@ -825,20 +837,15 @@ class Dart : public GBTree {
}
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
auto get_predictor = [&]() -> Predictor const* {
if (ctx_->IsCPU()) {
return cpu_predictor_.get();
} else if (ctx_->IsCUDA()) {
CHECK(this->gpu_predictor_);
return gpu_predictor_.get();
} else {
LOG(FATAL) << error::UnknownDevice();
return nullptr;
}
};
auto predict_impl = [&](size_t i) {
predts.predictions.Fill(0);
bool success{get_predictor()->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)};
bool success = this->ctx_->DispatchDevice(
[&] {
return cpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
},
[&] {
return gpu_predictor_->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
});
CHECK(success) << msg;
};
@@ -846,7 +853,15 @@ class Dart : public GBTree {
for (bst_tree_t i = tree_begin; i < tree_end; ++i) {
predict_impl(i);
if (i == tree_begin) {
get_predictor()->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
this->ctx_->DispatchDevice(
[&] {
this->cpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
model_);
},
[&] {
this->gpu_predictor_->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
model_);
});
}
// Multiple the tree weight
auto w = this->weight_drop_.at(i);