diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 0958803f3..20f5747e2 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -709,7 +709,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase): evals=evals, early_stopping_rounds=early_stopping_rounds, evals_result=evals_result, obj=obj, feval=feval, - verbose_eval=verbose, xgb_model=None, + verbose_eval=verbose, xgb_model=xgb_model, callbacks=callbacks) self.objective = xgb_options["objective"] diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 0826493b6..eb1b95664 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -609,3 +609,42 @@ def test_RFECV(): scale_pos_weight=0.5, silent=True) rfecv = RFECV(estimator=bst, step=1, cv=3, scoring='neg_log_loss') rfecv.fit(X, y) + + +def test_XGBClassifier_resume(): + from sklearn.datasets import load_breast_cancer + from sklearn.metrics import log_loss + + with TemporaryDirectory() as tempdir: + model1_path = os.path.join(tempdir, 'test_XGBClassifier.model') + model1_booster_path = os.path.join(tempdir, 'test_XGBClassifier.booster') + + X, Y = load_breast_cancer(return_X_y=True) + + model1 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8) + model1.fit(X, Y) + + pred1 = model1.predict(X) + log_loss1 = log_loss(pred1, Y) + + # file name of stored xgb model + model1.save_model(model1_path) + model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8) + model2.fit(X, Y, xgb_model=model1_path) + + pred2 = model2.predict(X) + log_loss2 = log_loss(pred2, Y) + + assert np.any(pred1 != pred2) + assert log_loss1 > log_loss2 + + # file name of 'Booster' instance Xgb model + model1.get_booster().save_model(model1_booster_path) + model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8) + model2.fit(X, Y, xgb_model=model1_booster_path) + + pred2 = model2.predict(X) + log_loss2 = log_loss(pred2, Y) + + assert np.any(pred1 != pred2) + assert log_loss1 > log_loss2