Fix inplace predict with fallback when base margin is used. (#9536)

- Copy meta info from proxy DMatrix.
- Use `std::call_once` to emit less warnings.
This commit is contained in:
Jiaming Yuan
2023-09-05 01:04:24 +08:00
committed by GitHub
parent d159ee8547
commit adea842c83
6 changed files with 62 additions and 63 deletions

View File

@@ -3,9 +3,11 @@
*/
#include "error_msg.h"
#include <mutex> // for call_once, once_flag
#include <sstream> // for stringstream
#include "../collective/communicator-inl.h" // for GetRank
#include "xgboost/context.h" // for Context
#include "xgboost/logging.h"
namespace xgboost::error {
@@ -26,34 +28,43 @@ void WarnDeprecatedGPUHist() {
}
void WarnManualUpdater() {
bool static thread_local logged{false};
if (logged) {
return;
}
LOG(WARNING)
<< "You have manually specified the `updater` parameter. The `tree_method` parameter "
"will be ignored. Incorrect sequence of updaters will produce undefined "
"behavior. For common uses, we recommend using `tree_method` parameter instead.";
logged = true;
static std::once_flag flag;
std::call_once(flag, [] {
LOG(WARNING)
<< "You have manually specified the `updater` parameter. The `tree_method` parameter "
"will be ignored. Incorrect sequence of updaters will produce undefined "
"behavior. For common uses, we recommend using `tree_method` parameter instead.";
});
}
void WarnDeprecatedGPUId() {
static thread_local bool logged{false};
if (logged) {
return;
}
auto msg = DeprecatedFunc("gpu_id", "2.0.0", "device");
msg += " E.g. device=cpu/cuda/cuda:0";
LOG(WARNING) << msg;
logged = true;
static std::once_flag flag;
std::call_once(flag, [] {
auto msg = DeprecatedFunc("gpu_id", "2.0.0", "device");
msg += " E.g. device=cpu/cuda/cuda:0";
LOG(WARNING) << msg;
});
}
void WarnEmptyDataset() {
static thread_local bool logged{false};
if (logged) {
return;
}
LOG(WARNING) << "Empty dataset at worker: " << collective::GetRank();
logged = true;
static std::once_flag flag;
std::call_once(flag,
[] { LOG(WARNING) << "Empty dataset at worker: " << collective::GetRank(); });
}
void MismatchedDevices(Context const* booster, Context const* data) {
static std::once_flag flag;
std::call_once(flag, [&] {
LOG(WARNING)
<< "Falling back to prediction using DMatrix due to mismatched devices. This might "
"lead to higher memory usage and slower performance. XGBoost is running on: "
<< booster->DeviceName() << ", while the input data is on: " << data->DeviceName() << ".\n"
<< R"(Potential solutions:
- Use a data structure that matches the device ordinal in the booster.
- Set the device for booster before call to inplace_predict.
This warning will only be shown once.
)";
});
}
} // namespace xgboost::error

View File

@@ -10,7 +10,8 @@
#include <limits> // for numeric_limits
#include <string> // for string
#include "xgboost/base.h" // for bst_feature_t
#include "xgboost/base.h" // for bst_feature_t
#include "xgboost/context.h" // for Context
#include "xgboost/logging.h"
#include "xgboost/string_view.h" // for StringView
@@ -94,5 +95,7 @@ constexpr StringView InvalidCUDAOrdinal() {
return "Invalid device. `device` is required to be CUDA and there must be at least one GPU "
"available for using GPU.";
}
void MismatchedDevices(Context const* booster, Context const* data);
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_