Fix prediction from loaded pickle. (#4516)
This commit is contained in:
parent
fed665ae8a
commit
b48f895027
@ -119,7 +119,7 @@ class Learner : public rabit::Serializable {
|
||||
bool pred_leaf = false,
|
||||
bool pred_contribs = false,
|
||||
bool approx_contribs = false,
|
||||
bool pred_interactions = false) const = 0;
|
||||
bool pred_interactions = false) = 0;
|
||||
|
||||
/*!
|
||||
* \brief Set additional attribute to the Booster.
|
||||
|
||||
@ -329,6 +329,7 @@ class LearnerImpl : public Learner {
|
||||
const std::string prefix = "SAVED_PARAM_";
|
||||
if (kv.first.find(prefix) == 0) {
|
||||
const std::string saved_param = kv.first.substr(prefix.length());
|
||||
bool is_gpu_predictor = saved_param == "predictor" && kv.second == "gpu_predictor";
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
if (saved_param == "predictor" || saved_param == "n_gpus"
|
||||
|| saved_param == "gpu_id") {
|
||||
@ -346,10 +347,14 @@ class LearnerImpl : public Learner {
|
||||
<< " * JVM packages: bst.setParam(\""
|
||||
<< saved_param << "\", [new value])";
|
||||
}
|
||||
#else
|
||||
if (is_gpu_predictor) {
|
||||
cfg_["predictor"] = "cpu_predictor";
|
||||
kv.second = "cpu_predictor";
|
||||
}
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
// No visible GPU in the current environment
|
||||
if (GPUSet::AllVisible().Size() == 0 &&
|
||||
(saved_param == "predictor" && kv.second == "gpu_predictor")) {
|
||||
if (is_gpu_predictor && GPUSet::AllVisible().Size() == 0) {
|
||||
cfg_["predictor"] = "cpu_predictor";
|
||||
kv.second = "cpu_predictor";
|
||||
}
|
||||
@ -543,7 +548,7 @@ class LearnerImpl : public Learner {
|
||||
void Predict(DMatrix* data, bool output_margin,
|
||||
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
|
||||
bool pred_leaf, bool pred_contribs, bool approx_contribs,
|
||||
bool pred_interactions) const override {
|
||||
bool pred_interactions) override {
|
||||
bool multiple_predictions = static_cast<int>(pred_leaf) +
|
||||
static_cast<int>(pred_interactions) +
|
||||
static_cast<int>(pred_contribs);
|
||||
@ -712,10 +717,11 @@ class LearnerImpl : public Learner {
|
||||
* \param ntree_limit limit number of trees used for boosted tree
|
||||
* predictor, when it equals 0, this means we are using all the trees
|
||||
*/
|
||||
inline void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) const {
|
||||
void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) {
|
||||
CHECK(gbm_ != nullptr)
|
||||
<< "Predict must happen after Load or InitModel";
|
||||
ConfigurationWithKnownData(data);
|
||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
||||
}
|
||||
|
||||
|
||||
@ -4,7 +4,9 @@ import unittest
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import xgboost as xgb
|
||||
from xgboost import XGBClassifier
|
||||
|
||||
model_path = './model.pkl'
|
||||
|
||||
@ -17,6 +19,17 @@ def build_dataset():
|
||||
return x, y
|
||||
|
||||
|
||||
def save_pickle(bst, path):
|
||||
with open(path, 'wb') as fd:
|
||||
pickle.dump(bst, fd)
|
||||
|
||||
|
||||
def load_pickle(path):
|
||||
with open(path, 'rb') as fd:
|
||||
bst = pickle.load(fd)
|
||||
return bst
|
||||
|
||||
|
||||
class TestPickling(unittest.TestCase):
|
||||
def test_pickling(self):
|
||||
x, y = build_dataset()
|
||||
@ -27,8 +40,7 @@ class TestPickling(unittest.TestCase):
|
||||
'verbosity': 1}
|
||||
bst = xgb.train(param, train_x)
|
||||
|
||||
with open(model_path, 'wb') as fd:
|
||||
pickle.dump(bst, fd)
|
||||
save_pickle(bst, model_path)
|
||||
args = ["pytest",
|
||||
"--verbose",
|
||||
"-s",
|
||||
@ -51,3 +63,30 @@ class TestPickling(unittest.TestCase):
|
||||
status = subprocess.call(command, env=env, shell=True)
|
||||
assert status == 0
|
||||
os.remove(model_path)
|
||||
|
||||
def test_predict_sklearn_pickle(self):
|
||||
x, y = build_dataset()
|
||||
|
||||
kwargs = {'tree_method': 'gpu_hist',
|
||||
'predictor': 'gpu_predictor',
|
||||
'verbosity': 2,
|
||||
'objective': 'binary:logistic',
|
||||
'n_estimators': 10}
|
||||
|
||||
model = XGBClassifier(**kwargs)
|
||||
model.fit(x, y)
|
||||
|
||||
save_pickle(model, "model.pkl")
|
||||
del model
|
||||
|
||||
# load model
|
||||
model: xgb.XGBClassifier = load_pickle("model.pkl")
|
||||
os.remove("model.pkl")
|
||||
|
||||
gpu_pred = model.predict(x, output_margin=True)
|
||||
|
||||
# Switch to CPU predictor
|
||||
bst = model.get_booster()
|
||||
bst.set_param({'predictor': 'cpu_predictor'})
|
||||
cpu_pred = model.predict(x, output_margin=True)
|
||||
np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user