Fix prediction from loaded pickle. (#4516)
This commit is contained in:
@@ -4,7 +4,9 @@ import unittest
|
||||
import numpy as np
|
||||
import subprocess
|
||||
import os
|
||||
import sys
|
||||
import xgboost as xgb
|
||||
from xgboost import XGBClassifier
|
||||
|
||||
model_path = './model.pkl'
|
||||
|
||||
@@ -17,6 +19,17 @@ def build_dataset():
|
||||
return x, y
|
||||
|
||||
|
||||
def save_pickle(bst, path):
|
||||
with open(path, 'wb') as fd:
|
||||
pickle.dump(bst, fd)
|
||||
|
||||
|
||||
def load_pickle(path):
|
||||
with open(path, 'rb') as fd:
|
||||
bst = pickle.load(fd)
|
||||
return bst
|
||||
|
||||
|
||||
class TestPickling(unittest.TestCase):
|
||||
def test_pickling(self):
|
||||
x, y = build_dataset()
|
||||
@@ -27,8 +40,7 @@ class TestPickling(unittest.TestCase):
|
||||
'verbosity': 1}
|
||||
bst = xgb.train(param, train_x)
|
||||
|
||||
with open(model_path, 'wb') as fd:
|
||||
pickle.dump(bst, fd)
|
||||
save_pickle(bst, model_path)
|
||||
args = ["pytest",
|
||||
"--verbose",
|
||||
"-s",
|
||||
@@ -51,3 +63,30 @@ class TestPickling(unittest.TestCase):
|
||||
status = subprocess.call(command, env=env, shell=True)
|
||||
assert status == 0
|
||||
os.remove(model_path)
|
||||
|
||||
def test_predict_sklearn_pickle(self):
|
||||
x, y = build_dataset()
|
||||
|
||||
kwargs = {'tree_method': 'gpu_hist',
|
||||
'predictor': 'gpu_predictor',
|
||||
'verbosity': 2,
|
||||
'objective': 'binary:logistic',
|
||||
'n_estimators': 10}
|
||||
|
||||
model = XGBClassifier(**kwargs)
|
||||
model.fit(x, y)
|
||||
|
||||
save_pickle(model, "model.pkl")
|
||||
del model
|
||||
|
||||
# load model
|
||||
model: xgb.XGBClassifier = load_pickle("model.pkl")
|
||||
os.remove("model.pkl")
|
||||
|
||||
gpu_pred = model.predict(x, output_margin=True)
|
||||
|
||||
# Switch to CPU predictor
|
||||
bst = model.get_booster()
|
||||
bst.set_param({'predictor': 'cpu_predictor'})
|
||||
cpu_pred = model.predict(x, output_margin=True)
|
||||
np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
|
||||
|
||||
Reference in New Issue
Block a user