Fix prediction from loaded pickle. (#4516)

This commit is contained in:
Jiaming Yuan
2019-05-30 15:05:09 +08:00
committed by GitHub
parent fed665ae8a
commit b48f895027
3 changed files with 53 additions and 8 deletions

View File

@@ -329,6 +329,7 @@ class LearnerImpl : public Learner {
const std::string prefix = "SAVED_PARAM_";
if (kv.first.find(prefix) == 0) {
const std::string saved_param = kv.first.substr(prefix.length());
bool is_gpu_predictor = saved_param == "predictor" && kv.second == "gpu_predictor";
#ifdef XGBOOST_USE_CUDA
if (saved_param == "predictor" || saved_param == "n_gpus"
|| saved_param == "gpu_id") {
@@ -346,10 +347,14 @@ class LearnerImpl : public Learner {
<< " * JVM packages: bst.setParam(\""
<< saved_param << "\", [new value])";
}
#else
if (is_gpu_predictor) {
cfg_["predictor"] = "cpu_predictor";
kv.second = "cpu_predictor";
}
#endif // XGBOOST_USE_CUDA
// NO visiable GPU on current environment
if (GPUSet::AllVisible().Size() == 0 &&
(saved_param == "predictor" && kv.second == "gpu_predictor")) {
if (is_gpu_predictor && GPUSet::AllVisible().Size() == 0) {
cfg_["predictor"] = "cpu_predictor";
kv.second = "cpu_predictor";
}
@@ -543,7 +548,7 @@ class LearnerImpl : public Learner {
void Predict(DMatrix* data, bool output_margin,
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
bool pred_leaf, bool pred_contribs, bool approx_contribs,
bool pred_interactions) const override {
bool pred_interactions) override {
bool multiple_predictions = static_cast<int>(pred_leaf) +
static_cast<int>(pred_interactions) +
static_cast<int>(pred_contribs);
@@ -712,10 +717,11 @@ class LearnerImpl : public Learner {
* \param ntree_limit limit number of trees used for boosted tree
* predictor, when it equals 0, this means we are using all the trees
*/
inline void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit = 0) const {
void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit = 0) {
CHECK(gbm_ != nullptr)
<< "Predict must happen after Load or InitModel";
ConfigurationWithKnownData(data);
gbm_->PredictBatch(data, out_preds, ntree_limit);
}