[backport] Fix inplace predict with fallback when base margin is used. (#9536) (#9548)

- Copy meta info from proxy DMatrix.
- Use `std::call_once` to emit less warnings.
This commit is contained in:
Jiaming Yuan
2023-09-05 23:38:06 +08:00
committed by GitHub
parent b67c2ed96d
commit 3fde9361d7
6 changed files with 62 additions and 63 deletions

View File

@@ -85,25 +85,6 @@ bool UpdatersMatched(std::vector<std::string> updater_seq,
return name == up->Name();
});
}
void MismatchedDevices(Context const* booster, Context const* data) {
bool thread_local static logged{false};
if (logged) {
return;
}
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. This might "
"lead to higher memory usage and slower performance. XGBoost is running on: "
<< booster->DeviceName() << ", while the input data is on: " << data->DeviceName()
<< ".\n"
<< R"(Potential solutions:
- Use a data structure that matches the device ordinal in the booster.
- Set the device for booster before call to inplace_predict.
This warning will only be shown once for each thread. Subsequent warnings made by the
current thread will be suppressed.
)";
logged = true;
}
} // namespace
void GBTree::Configure(Args const& cfg) {
@@ -554,7 +535,7 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
if (p_m->Ctx()->Device() != this->ctx_->Device()) {
MismatchedDevices(this->ctx_, p_m->Ctx());
error::MismatchedDevices(this->ctx_, p_m->Ctx());
CHECK_EQ(out_preds->version, 0);
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
CHECK(proxy) << error::InplacePredictProxy();
@@ -807,7 +788,7 @@ class Dart : public GBTree {
auto n_groups = model_.learner_model_param->num_output_group;
if (ctx_->Device() != p_fmat->Ctx()->Device()) {
MismatchedDevices(ctx_, p_fmat->Ctx());
error::MismatchedDevices(ctx_, p_fmat->Ctx());
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
CHECK(proxy) << error::InplacePredictProxy();
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);