Define the new device parameter. (#9362)
This commit is contained in:
@@ -84,6 +84,25 @@ bool UpdatersMatched(std::vector<std::string> updater_seq,
|
||||
return name == up->Name();
|
||||
});
|
||||
}
|
||||
|
||||
void MismatchedDevices(Context const* booster, Context const* data) {
|
||||
bool thread_local static logged{false};
|
||||
if (logged) {
|
||||
return;
|
||||
}
|
||||
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. This might "
|
||||
"lead to higher memory usage and slower performance. XGBoost is running on: "
|
||||
<< booster->DeviceName() << ", while the input data is on: " << data->DeviceName()
|
||||
<< ".\n"
|
||||
<< R"(Potential solutions:
|
||||
- Use a data structure that matches the device ordinal in the booster.
|
||||
- Set the device for booster before call to inplace_predict.
|
||||
|
||||
This warning will only be shown once, and subsequent warnings made by the current thread will be
|
||||
suppressed.
|
||||
)";
|
||||
logged = true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void GBTree::Configure(Args const& cfg) {
|
||||
@@ -208,6 +227,7 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
|
||||
bst_target_t const n_groups = model_.learner_model_param->OutputLength();
|
||||
monitor_.Start("BoostNewTrees");
|
||||
|
||||
predt->predictions.SetDevice(ctx_->Ordinal());
|
||||
auto out = linalg::MakeTensorView(ctx_, &predt->predictions, p_fmat->Info().num_row_,
|
||||
model_.learner_model_param->OutputLength());
|
||||
CHECK_NE(n_groups, 0);
|
||||
@@ -521,18 +541,6 @@ void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds,
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
inline void MismatchedDevices(Context const* booster, Context const* data) {
|
||||
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
|
||||
<< "is running on: " << booster->DeviceName()
|
||||
<< ", while the input data is on: " << data->DeviceName() << ".\n"
|
||||
<< R"(Potential solutions:
|
||||
- Use a data structure that matches the device ordinal in the booster.
|
||||
- Set the device for booster before call to inplace_predict.
|
||||
)";
|
||||
}
|
||||
}; // namespace
|
||||
|
||||
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
|
||||
bst_layer_t layer_begin, bst_layer_t layer_end) {
|
||||
// dispatch to const function.
|
||||
|
||||
Reference in New Issue
Block a user