Fix prediction configuration. (#7159)
After the predictor parameter was added to the constructor, this configuration was broken.
This commit is contained in:
parent
9600ca83f3
commit
3f38d983a6
@ -798,8 +798,8 @@ class XGBModel(XGBModelBase):
|
|||||||
# error with incompatible data type.
|
# error with incompatible data type.
|
||||||
# Inplace predict doesn't handle as many data types as DMatrix, but it's
|
# Inplace predict doesn't handle as many data types as DMatrix, but it's
|
||||||
# sufficient for dask interface where input is simpiler.
|
# sufficient for dask interface where input is simpiler.
|
||||||
params = self.get_params()
|
predictor = self.get_params().get("predictor", None)
|
||||||
if params.get("predictor", None) is None and self.booster != "gblinear":
|
if predictor in ("auto", None) and self.booster != "gblinear":
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@ -1254,3 +1254,20 @@ def test_estimator_reg(estimator, check):
|
|||||||
estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())
|
estimator.set_params(**xgb.XGBRegressor().fit(X, y).get_params())
|
||||||
|
|
||||||
check(estimator)
|
check(estimator)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prediction_config():
|
||||||
|
reg = xgb.XGBRegressor()
|
||||||
|
assert reg._can_use_inplace_predict() is True
|
||||||
|
|
||||||
|
reg.set_params(predictor="cpu_predictor")
|
||||||
|
assert reg._can_use_inplace_predict() is False
|
||||||
|
|
||||||
|
reg.set_params(predictor="auto")
|
||||||
|
assert reg._can_use_inplace_predict() is True
|
||||||
|
|
||||||
|
reg.set_params(predictor=None)
|
||||||
|
assert reg._can_use_inplace_predict() is True
|
||||||
|
|
||||||
|
reg.set_params(booster="gblinear")
|
||||||
|
assert reg._can_use_inplace_predict() is False
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user