Define the new device parameter. (#9362)

This commit is contained in:
Jiaming Yuan
2023-07-13 19:30:25 +08:00
committed by GitHub
parent 2d0cd2817e
commit 04aff3af8e
63 changed files with 827 additions and 477 deletions

View File

@@ -84,6 +84,25 @@ bool UpdatersMatched(std::vector<std::string> updater_seq,
return name == up->Name();
});
}
void MismatchedDevices(Context const* booster, Context const* data) {
bool thread_local static logged{false};
if (logged) {
return;
}
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. This might "
"lead to higher memory usage and slower performance. XGBoost is running on: "
<< booster->DeviceName() << ", while the input data is on: " << data->DeviceName()
<< ".\n"
<< R"(Potential solutions:
- Use a data structure that matches the device ordinal in the booster.
- Set the device for booster before call to inplace_predict.
This warning will only be shown once, and subsequent warnings made by the current thread will be
suppressed.
)";
logged = true;
}
} // namespace
void GBTree::Configure(Args const& cfg) {
@@ -208,6 +227,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
bst_target_t const n_groups = model_.learner_model_param->OutputLength();
monitor_.Start("BoostNewTrees");
predt->predictions.SetDevice(ctx_->Ordinal());
auto out = linalg::MakeTensorView(ctx_, &predt->predictions, p_fmat->Info().num_row_,
model_.learner_model_param->OutputLength());
CHECK_NE(n_groups, 0);
@@ -521,18 +541,6 @@ void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds,
}
}
namespace {
inline void MismatchedDevices(Context const* booster, Context const* data) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << booster->DeviceName()
<< ", while the input data is on: " << data->DeviceName() << ".\n"
<< R"(Potential solutions:
- Use a data structure that matches the device ordinal in the booster.
- Set the device for booster before call to inplace_predict.
)";
}
}; // namespace
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
bst_layer_t layer_begin, bst_layer_t layer_end) {
// dispatch to const function.