Return base score as intercept. (#9486)
commit 7f29a238e6 (parent 0bb87b5b35)
@@ -1359,25 +1359,25 @@ class XGBModel(XGBModelBase):
     @property
     def intercept_(self) -> np.ndarray:
-        """
-        Intercept (bias) property
+        """Intercept (bias) property
 
-        .. note:: Intercept is defined only for linear learners
+        For tree-based model, the returned value is the `base_score`.
 
-           Intercept (bias) is only defined when the linear model is chosen as base
-           learner (`booster=gblinear`). It is not defined for other base learner types,
-           such as tree learners (`booster=gbtree`).
-
         Returns
         -------
         intercept_ : array of shape ``(1,)`` or ``[n_classes]``
 
        """
-        if self.get_xgb_params()["booster"] != "gblinear":
-            raise AttributeError(
-                f"Intercept (bias) is not defined for Booster type {self.booster}"
-            )
+        booster_config = self.get_xgb_params()["booster"]
         b = self.get_booster()
-        return np.array(json.loads(b.get_dump(dump_format="json")[0])["bias"])
+        if booster_config != "gblinear":  # gbtree, dart
+            config = json.loads(b.save_config())
+            intercept = config["learner"]["learner_model_param"]["base_score"]
+            return np.array([float(intercept)], dtype=np.float32)
+
+        return np.array(
+            json.loads(b.get_dump(dump_format="json")[0])["bias"], dtype=np.float32
+        )
 
 
 PredtT = TypeVar("PredtT", bound=np.ndarray)
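For context, a minimal sketch of what the two branches above return, assuming an XGBoost build that includes this change; the random data here is illustrative, not from the commit:

    import json

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X, y = rng.normal(size=(64, 3)), rng.normal(size=64)

    # Tree booster: intercept_ now mirrors base_score from save_config().
    reg = xgb.XGBRegressor(n_estimators=4).fit(X, y)
    config = json.loads(reg.get_booster().save_config())
    base_score = float(config["learner"]["learner_model_param"]["base_score"])
    np.testing.assert_allclose(reg.intercept_, np.array([base_score], dtype=np.float32))

    # Linear booster: intercept_ still reads the bias from the JSON model dump.
    lin = xgb.XGBRegressor(booster="gblinear").fit(X, y)
    dump = json.loads(lin.get_booster().get_dump(dump_format="json")[0])
    np.testing.assert_allclose(lin.intercept_, np.array(dump["bias"], dtype=np.float32))

Both branches now coerce to float32, so the property has a consistent dtype regardless of booster type, instead of raising AttributeError for gbtree and dart.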
@@ -1507,6 +1507,7 @@ def test_evaluation_metric():
     # shape check inside the `merror` function
     clf.fit(X, y, eval_set=[(X, y)])
 
+
 def test_weighted_evaluation_metric():
     from sklearn.datasets import make_hastie_10_2
     from sklearn.metrics import log_loss
@@ -1544,3 +1545,18 @@ def test_weighted_evaluation_metric():
         internal["validation_0"]["logloss"],
         atol=1e-6
     )
+
+
+def test_intercept() -> None:
+    X, y, w = tm.make_regression(256, 3, use_cupy=False)
+    reg = xgb.XGBRegressor()
+    reg.fit(X, y, sample_weight=w)
+    result = reg.intercept_
+    assert result.dtype == np.float32
+    assert result[0] < 0.5
+
+    reg = xgb.XGBRegressor(booster="gblinear")
+    reg.fit(X, y, sample_weight=w)
+    result = reg.intercept_
+    assert result.dtype == np.float32
+    assert result[0] < 0.5
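The `result[0] < 0.5` assertions rely on the training targets being roughly centred: for the default `reg:squarederror` objective, recent XGBoost versions estimate `base_score` from the (weighted) target mean. A hedged sketch of that relationship, using illustrative data in place of `tm.make_regression`:

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(256, 3))
    y = X @ np.array([0.5, -0.25, 0.1]) + rng.normal(scale=0.1, size=256)
    w = rng.uniform(0.5, 1.5, size=256)

    reg = xgb.XGBRegressor().fit(X, y, sample_weight=w)

    # The fitted intercept should land close to the weighted target mean,
    # which is near zero for this centred toy data; 0.5 is a loose bound.
    assert abs(reg.intercept_[0] - np.average(y, weights=w)) < 0.5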