Fix inplace predict with fallback when base margin is used. (#9536)
- Copy meta info from proxy DMatrix. - Use `std::call_once` to emit less warnings.
This commit is contained in:
@@ -3,9 +3,11 @@
|
||||
*/
|
||||
#include "error_msg.h"
|
||||
|
||||
#include <mutex> // for call_once, once_flag
|
||||
#include <sstream> // for stringstream
|
||||
|
||||
#include "../collective/communicator-inl.h" // for GetRank
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost::error {
|
||||
@@ -26,34 +28,43 @@ void WarnDeprecatedGPUHist() {
|
||||
}
|
||||
|
||||
void WarnManualUpdater() {
|
||||
bool static thread_local logged{false};
|
||||
if (logged) {
|
||||
return;
|
||||
}
|
||||
LOG(WARNING)
|
||||
<< "You have manually specified the `updater` parameter. The `tree_method` parameter "
|
||||
"will be ignored. Incorrect sequence of updaters will produce undefined "
|
||||
"behavior. For common uses, we recommend using `tree_method` parameter instead.";
|
||||
logged = true;
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, [] {
|
||||
LOG(WARNING)
|
||||
<< "You have manually specified the `updater` parameter. The `tree_method` parameter "
|
||||
"will be ignored. Incorrect sequence of updaters will produce undefined "
|
||||
"behavior. For common uses, we recommend using `tree_method` parameter instead.";
|
||||
});
|
||||
}
|
||||
|
||||
void WarnDeprecatedGPUId() {
|
||||
static thread_local bool logged{false};
|
||||
if (logged) {
|
||||
return;
|
||||
}
|
||||
auto msg = DeprecatedFunc("gpu_id", "2.0.0", "device");
|
||||
msg += " E.g. device=cpu/cuda/cuda:0";
|
||||
LOG(WARNING) << msg;
|
||||
logged = true;
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, [] {
|
||||
auto msg = DeprecatedFunc("gpu_id", "2.0.0", "device");
|
||||
msg += " E.g. device=cpu/cuda/cuda:0";
|
||||
LOG(WARNING) << msg;
|
||||
});
|
||||
}
|
||||
|
||||
void WarnEmptyDataset() {
|
||||
static thread_local bool logged{false};
|
||||
if (logged) {
|
||||
return;
|
||||
}
|
||||
LOG(WARNING) << "Empty dataset at worker: " << collective::GetRank();
|
||||
logged = true;
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag,
|
||||
[] { LOG(WARNING) << "Empty dataset at worker: " << collective::GetRank(); });
|
||||
}
|
||||
|
||||
void MismatchedDevices(Context const* booster, Context const* data) {
|
||||
static std::once_flag flag;
|
||||
std::call_once(flag, [&] {
|
||||
LOG(WARNING)
|
||||
<< "Falling back to prediction using DMatrix due to mismatched devices. This might "
|
||||
"lead to higher memory usage and slower performance. XGBoost is running on: "
|
||||
<< booster->DeviceName() << ", while the input data is on: " << data->DeviceName() << ".\n"
|
||||
<< R"(Potential solutions:
|
||||
- Use a data structure that matches the device ordinal in the booster.
|
||||
- Set the device for booster before call to inplace_predict.
|
||||
|
||||
This warning will only be shown once.
|
||||
)";
|
||||
});
|
||||
}
|
||||
} // namespace xgboost::error
|
||||
|
||||
@@ -10,7 +10,8 @@
|
||||
#include <limits> // for numeric_limits
|
||||
#include <string> // for string
|
||||
|
||||
#include "xgboost/base.h" // for bst_feature_t
|
||||
#include "xgboost/base.h" // for bst_feature_t
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
@@ -94,5 +95,7 @@ constexpr StringView InvalidCUDAOrdinal() {
|
||||
return "Invalid device. `device` is required to be CUDA and there must be at least one GPU "
|
||||
"available for using GPU.";
|
||||
}
|
||||
|
||||
void MismatchedDevices(Context const* booster, Context const* data);
|
||||
} // namespace xgboost::error
|
||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||
|
||||
@@ -55,6 +55,7 @@ std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
|
||||
}
|
||||
|
||||
CHECK(p_fmat) << "Failed to fallback.";
|
||||
p_fmat->Info() = proxy->Info().Copy();
|
||||
return p_fmat;
|
||||
}
|
||||
} // namespace xgboost::data
|
||||
|
||||
@@ -85,25 +85,6 @@ bool UpdatersMatched(std::vector<std::string> updater_seq,
|
||||
return name == up->Name();
|
||||
});
|
||||
}
|
||||
|
||||
void MismatchedDevices(Context const* booster, Context const* data) {
|
||||
bool thread_local static logged{false};
|
||||
if (logged) {
|
||||
return;
|
||||
}
|
||||
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. This might "
|
||||
"lead to higher memory usage and slower performance. XGBoost is running on: "
|
||||
<< booster->DeviceName() << ", while the input data is on: " << data->DeviceName()
|
||||
<< ".\n"
|
||||
<< R"(Potential solutions:
|
||||
- Use a data structure that matches the device ordinal in the booster.
|
||||
- Set the device for booster before call to inplace_predict.
|
||||
|
||||
This warning will only be shown once for each thread. Subsequent warnings made by the
|
||||
current thread will be suppressed.
|
||||
)";
|
||||
logged = true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void GBTree::Configure(Args const& cfg) {
|
||||
@@ -557,7 +538,7 @@ void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
|
||||
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
|
||||
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
|
||||
if (p_m->Ctx()->Device() != this->ctx_->Device()) {
|
||||
MismatchedDevices(this->ctx_, p_m->Ctx());
|
||||
error::MismatchedDevices(this->ctx_, p_m->Ctx());
|
||||
CHECK_EQ(out_preds->version, 0);
|
||||
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
|
||||
CHECK(proxy) << error::InplacePredictProxy();
|
||||
@@ -810,7 +791,7 @@ class Dart : public GBTree {
|
||||
auto n_groups = model_.learner_model_param->num_output_group;
|
||||
|
||||
if (ctx_->Device() != p_fmat->Ctx()->Device()) {
|
||||
MismatchedDevices(ctx_, p_fmat->Ctx());
|
||||
error::MismatchedDevices(ctx_, p_fmat->Ctx());
|
||||
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
|
||||
CHECK(proxy) << error::InplacePredictProxy();
|
||||
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
|
||||
|
||||
Reference in New Issue
Block a user