From 3f38d983a671632bfb0dbdcdead6f2aa408a06f2 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 11 Aug 2021 16:34:36 +0800 Subject: [PATCH] Fix prediction configuration. (#7159) After the predictor parameter was added to the constructor, this configuration was broken. --- python-package/xgboost/sklearn.py | 4 ++-- tests/python/test_with_sklearn.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index cbce3dc05..ef543f5f6 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -798,8 +798,8 @@ class XGBModel(XGBModelBase): # error with incompatible data type. # Inplace predict doesn't handle as many data types as DMatrix, but it's # sufficient for dask interface where input is simpiler. - params = self.get_params() - if params.get("predictor", None) is None and self.booster != "gblinear": + predictor = self.get_params().get("predictor", None) + if predictor in ("auto", None) and self.booster != "gblinear": return True return False diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index cf31929cc..afc1d857c 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1254,3 +1254,20 @@ def test_estimator_reg(estimator, check): estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params()) check(estimator) + + +def test_prediction_config(): + reg = xgb.XGBRegressor() + assert reg._can_use_inplace_predict() is True + + reg.set_params(predictor="cpu_predictor") + assert reg._can_use_inplace_predict() is False + + reg.set_params(predictor="auto") + assert reg._can_use_inplace_predict() is True + + reg.set_params(predictor=None) + assert reg._can_use_inplace_predict() is True + + reg.set_params(booster="gblinear") + assert reg._can_use_inplace_predict() is False