Fix prediction configuration. (#7159)

After the predictor parameter was added to the constructor, this configuration was broken.
This commit is contained in:
Jiaming Yuan 2021-08-11 16:34:36 +08:00 committed by GitHub
parent 9600ca83f3
commit 3f38d983a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 2 deletions

View File

@ -798,8 +798,8 @@ class XGBModel(XGBModelBase):
# error with incompatible data type.
# Inplace predict doesn't handle as many data types as DMatrix, but it's
# sufficient for dask interface where input is simpiler.
params = self.get_params()
if params.get("predictor", None) is None and self.booster != "gblinear":
predictor = self.get_params().get("predictor", None)
if predictor in ("auto", None) and self.booster != "gblinear":
return True
return False

View File

@ -1254,3 +1254,20 @@ def test_estimator_reg(estimator, check):
estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())
check(estimator)
def test_prediction_config():
reg = xgb.XGBRegressor()
assert reg._can_use_inplace_predict() is True
reg.set_params(predictor="cpu_predictor")
assert reg._can_use_inplace_predict() is False
reg.set_params(predictor="auto")
assert reg._can_use_inplace_predict() is True
reg.set_params(predictor=None)
assert reg._can_use_inplace_predict() is True
reg.set_params(booster="gblinear")
assert reg._can_use_inplace_predict() is False