Fix prediction from loaded pickle. (#4516)
This commit is contained in:
@@ -329,6 +329,7 @@ class LearnerImpl : public Learner {
|
||||
const std::string prefix = "SAVED_PARAM_";
|
||||
if (kv.first.find(prefix) == 0) {
|
||||
const std::string saved_param = kv.first.substr(prefix.length());
|
||||
bool is_gpu_predictor = saved_param == "predictor" && kv.second == "gpu_predictor";
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
if (saved_param == "predictor" || saved_param == "n_gpus"
|
||||
|| saved_param == "gpu_id") {
|
||||
@@ -346,10 +347,14 @@ class LearnerImpl : public Learner {
|
||||
<< " * JVM packages: bst.setParam(\""
|
||||
<< saved_param << "\", [new value])";
|
||||
}
|
||||
#else
|
||||
if (is_gpu_predictor) {
|
||||
cfg_["predictor"] = "cpu_predictor";
|
||||
kv.second = "cpu_predictor";
|
||||
}
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
// NO visiable GPU on current environment
|
||||
if (GPUSet::AllVisible().Size() == 0 &&
|
||||
(saved_param == "predictor" && kv.second == "gpu_predictor")) {
|
||||
if (is_gpu_predictor && GPUSet::AllVisible().Size() == 0) {
|
||||
cfg_["predictor"] = "cpu_predictor";
|
||||
kv.second = "cpu_predictor";
|
||||
}
|
||||
@@ -543,7 +548,7 @@ class LearnerImpl : public Learner {
|
||||
void Predict(DMatrix* data, bool output_margin,
|
||||
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
|
||||
bool pred_leaf, bool pred_contribs, bool approx_contribs,
|
||||
bool pred_interactions) const override {
|
||||
bool pred_interactions) override {
|
||||
bool multiple_predictions = static_cast<int>(pred_leaf) +
|
||||
static_cast<int>(pred_interactions) +
|
||||
static_cast<int>(pred_contribs);
|
||||
@@ -712,10 +717,11 @@ class LearnerImpl : public Learner {
|
||||
* \param ntree_limit limit number of trees used for boosted tree
|
||||
* predictor, when it equals 0, this means we are using all the trees
|
||||
*/
|
||||
inline void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) const {
|
||||
void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) {
|
||||
CHECK(gbm_ != nullptr)
|
||||
<< "Predict must happen after Load or InitModel";
|
||||
ConfigurationWithKnownData(data);
|
||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user