Fix prediction configuration. (#7159)
After the predictor parameter was added to the constructor, this configuration was broken.
This commit is contained in:
parent
9600ca83f3
commit
3f38d983a6
@ -798,8 +798,8 @@ class XGBModel(XGBModelBase):
|
||||
# error with incompatible data type.
|
||||
# Inplace predict doesn't handle as many data types as DMatrix, but it's
|
||||
# sufficient for dask interface where input is simpiler.
|
||||
params = self.get_params()
|
||||
if params.get("predictor", None) is None and self.booster != "gblinear":
|
||||
predictor = self.get_params().get("predictor", None)
|
||||
if predictor in ("auto", None) and self.booster != "gblinear":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@ -1254,3 +1254,20 @@ def test_estimator_reg(estimator, check):
|
||||
estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())
|
||||
|
||||
check(estimator)
|
||||
|
||||
|
||||
def test_prediction_config():
|
||||
reg = xgb.XGBRegressor()
|
||||
assert reg._can_use_inplace_predict() is True
|
||||
|
||||
reg.set_params(predictor="cpu_predictor")
|
||||
assert reg._can_use_inplace_predict() is False
|
||||
|
||||
reg.set_params(predictor="auto")
|
||||
assert reg._can_use_inplace_predict() is True
|
||||
|
||||
reg.set_params(predictor=None)
|
||||
assert reg._can_use_inplace_predict() is True
|
||||
|
||||
reg.set_params(booster="gblinear")
|
||||
assert reg._can_use_inplace_predict() is False
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user