Add base_margin for evaluation dataset. (#6591)
* Add base margin to evaluation datasets. * Unify the code base for evaluation matrices.
This commit is contained in:
parent
4bf23c2391
commit
740d042255
@ -38,7 +38,8 @@ from .core import Objective, Metric
|
|||||||
from .core import _deprecate_positional_args
|
from .core import _deprecate_positional_args
|
||||||
from .training import train as worker_train
|
from .training import train as worker_train
|
||||||
from .tracker import RabitTracker, get_host_ip
|
from .tracker import RabitTracker, get_host_ip
|
||||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
|
from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase
|
||||||
|
from .sklearn import _wrap_evaluation_matrices, _objective_decorator
|
||||||
from .sklearn import XGBRankerMixIn
|
from .sklearn import XGBRankerMixIn
|
||||||
from .sklearn import xgboost_model_doc
|
from .sklearn import xgboost_model_doc
|
||||||
from .sklearn import _cls_predict_proba
|
from .sklearn import _cls_predict_proba
|
||||||
@ -588,7 +589,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
|||||||
weight: Optional[_DaskCollection] = None,
|
weight: Optional[_DaskCollection] = None,
|
||||||
base_margin: Optional[_DaskCollection] = None,
|
base_margin: Optional[_DaskCollection] = None,
|
||||||
missing: float = None,
|
missing: float = None,
|
||||||
silent: bool = False,
|
silent: bool = False, # disable=unused-argument
|
||||||
feature_names: Optional[Union[str, List[str]]] = None,
|
feature_names: Optional[Union[str, List[str]]] = None,
|
||||||
feature_types: Optional[Union[Any, List[Any]]] = None,
|
feature_types: Optional[Union[Any, List[Any]]] = None,
|
||||||
max_bin: int = 256,
|
max_bin: int = 256,
|
||||||
@ -1292,44 +1293,24 @@ def inplace_predict(
|
|||||||
missing=missing)
|
missing=missing)
|
||||||
|
|
||||||
|
|
||||||
async def _evaluation_matrices(
|
async def _async_wrap_evaluation_matrices(
|
||||||
client: "distributed.Client",
|
client: "distributed.Client", **kwargs: Any
|
||||||
validation_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
|
||||||
sample_weight: Optional[List[_DaskCollection]],
|
"""A switch function for async environment."""
|
||||||
sample_qid: Optional[List[_DaskCollection]],
|
def _inner(**kwargs: Any) -> DaskDMatrix:
|
||||||
missing: float
|
m = DaskDMatrix(client=client, **kwargs)
|
||||||
) -> Optional[List[Tuple[DaskDMatrix, str]]]:
|
return m
|
||||||
'''
|
train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs)
|
||||||
Parameters
|
train_dmatrix = await train_dmatrix
|
||||||
----------
|
if evals is None:
|
||||||
validation_set: list of tuples
|
return train_dmatrix, evals
|
||||||
Each tuple contains a validation dataset including input X and label y.
|
awaited = []
|
||||||
E.g.:
|
for e in evals:
|
||||||
|
if e[0] is train_dmatrix: # already awaited
|
||||||
.. code-block:: python
|
awaited.append(e)
|
||||||
|
continue
|
||||||
[(X_0, y_0), (X_1, y_1), ... ]
|
awaited.append((await e[0], e[1]))
|
||||||
|
return train_dmatrix, awaited
|
||||||
sample_weights: list of arrays
|
|
||||||
The weight vector for validation data.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
evals: list of validation DMatrix
|
|
||||||
'''
|
|
||||||
evals: Optional[List[Tuple[DaskDMatrix, str]]] = []
|
|
||||||
if validation_set is not None:
|
|
||||||
assert isinstance(validation_set, list)
|
|
||||||
for i, e in enumerate(validation_set):
|
|
||||||
w = sample_weight[i] if sample_weight is not None else None
|
|
||||||
qid = sample_qid[i] if sample_qid is not None else None
|
|
||||||
dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
|
|
||||||
weight=w, missing=missing, qid=qid)
|
|
||||||
assert isinstance(evals, list)
|
|
||||||
evals.append((dmat, 'validation_{}'.format(i)))
|
|
||||||
else:
|
|
||||||
evals = None
|
|
||||||
return evals
|
|
||||||
|
|
||||||
|
|
||||||
class DaskScikitLearnBase(XGBModel):
|
class DaskScikitLearnBase(XGBModel):
|
||||||
@ -1337,7 +1318,6 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
|
|
||||||
_client = None
|
_client = None
|
||||||
|
|
||||||
@_deprecate_positional_args
|
|
||||||
async def _predict_async(
|
async def _predict_async(
|
||||||
self, data: _DaskCollection,
|
self, data: _DaskCollection,
|
||||||
output_margin: bool = False,
|
output_margin: bool = False,
|
||||||
@ -1404,25 +1384,30 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
|||||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
||||||
eval_metric: Optional[Union[str, List[str], Metric]],
|
eval_metric: Optional[Union[str, List[str], Metric]],
|
||||||
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
||||||
|
base_margin_eval_set: Optional[List[_DaskCollection]],
|
||||||
early_stopping_rounds: int,
|
early_stopping_rounds: int,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
xgb_model: Optional[Union[Booster, XGBModel]],
|
xgb_model: Optional[Union[Booster, XGBModel]],
|
||||||
feature_weights: Optional[_DaskCollection],
|
feature_weights: Optional[_DaskCollection],
|
||||||
callbacks: Optional[List[TrainingCallback]],
|
callbacks: Optional[List[TrainingCallback]],
|
||||||
) -> _DaskCollection:
|
) -> _DaskCollection:
|
||||||
dtrain = await DaskDMatrix(
|
params = self.get_xgb_params()
|
||||||
|
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||||
client=self.client,
|
client=self.client,
|
||||||
data=X,
|
X=X,
|
||||||
label=y,
|
y=y,
|
||||||
weight=sample_weight,
|
group=None,
|
||||||
|
qid=None,
|
||||||
|
sample_weight=sample_weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
feature_weights=feature_weights,
|
feature_weights=feature_weights,
|
||||||
|
eval_set=eval_set,
|
||||||
|
sample_weight_eval_set=sample_weight_eval_set,
|
||||||
|
base_margin_eval_set=base_margin_eval_set,
|
||||||
|
eval_group=None,
|
||||||
|
eval_qid=None,
|
||||||
missing=self.missing,
|
missing=self.missing,
|
||||||
)
|
)
|
||||||
params = self.get_xgb_params()
|
|
||||||
evals = await _evaluation_matrices(
|
|
||||||
self.client, eval_set, sample_weight_eval_set, None, self.missing
|
|
||||||
)
|
|
||||||
|
|
||||||
if callable(self.objective):
|
if callable(self.objective):
|
||||||
obj = _objective_decorator(self.objective)
|
obj = _objective_decorator(self.objective)
|
||||||
@ -1449,7 +1434,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
|||||||
self.evals_result_ = results["history"]
|
self.evals_result_ = results["history"]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
# pylint: disable=missing-docstring
|
# pylint: disable=missing-docstring, disable=unused-argument
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
@ -1464,25 +1449,13 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
|||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
||||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||||
|
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||||
feature_weights: Optional[_DaskCollection] = None,
|
feature_weights: Optional[_DaskCollection] = None,
|
||||||
callbacks: Optional[List[TrainingCallback]] = None
|
callbacks: Optional[List[TrainingCallback]] = None
|
||||||
) -> "DaskXGBRegressor":
|
) -> "DaskXGBRegressor":
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
return self.client.sync(
|
args = {k: v for k, v in locals().items() if k != "self"}
|
||||||
self._fit_async,
|
return self.client.sync(self._fit_async, **args)
|
||||||
X=X,
|
|
||||||
y=y,
|
|
||||||
sample_weight=sample_weight,
|
|
||||||
base_margin=base_margin,
|
|
||||||
eval_set=eval_set,
|
|
||||||
eval_metric=eval_metric,
|
|
||||||
sample_weight_eval_set=sample_weight_eval_set,
|
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
|
||||||
verbose=verbose,
|
|
||||||
xgb_model=xgb_model,
|
|
||||||
feature_weights=feature_weights,
|
|
||||||
callbacks=callbacks,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@xgboost_model_doc(
|
@xgboost_model_doc(
|
||||||
@ -1497,20 +1470,30 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
||||||
eval_metric: Optional[Union[str, List[str], Metric]],
|
eval_metric: Optional[Union[str, List[str], Metric]],
|
||||||
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
||||||
|
base_margin_eval_set: Optional[List[_DaskCollection]],
|
||||||
early_stopping_rounds: int,
|
early_stopping_rounds: int,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
xgb_model: Optional[Union[Booster, XGBModel]],
|
xgb_model: Optional[Union[Booster, XGBModel]],
|
||||||
feature_weights: Optional[_DaskCollection],
|
feature_weights: Optional[_DaskCollection],
|
||||||
callbacks: Optional[List[TrainingCallback]]
|
callbacks: Optional[List[TrainingCallback]]
|
||||||
) -> "DaskXGBClassifier":
|
) -> "DaskXGBClassifier":
|
||||||
dtrain = await DaskDMatrix(client=self.client,
|
params = self.get_xgb_params()
|
||||||
data=X,
|
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||||
label=y,
|
self.client,
|
||||||
weight=sample_weight,
|
X=X,
|
||||||
|
y=y,
|
||||||
|
group=None,
|
||||||
|
qid=None,
|
||||||
|
sample_weight=sample_weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
feature_weights=feature_weights,
|
feature_weights=feature_weights,
|
||||||
missing=self.missing)
|
eval_set=eval_set,
|
||||||
params = self.get_xgb_params()
|
sample_weight_eval_set=sample_weight_eval_set,
|
||||||
|
base_margin_eval_set=base_margin_eval_set,
|
||||||
|
eval_group=None,
|
||||||
|
eval_qid=None,
|
||||||
|
missing=self.missing,
|
||||||
|
)
|
||||||
|
|
||||||
# pylint: disable=attribute-defined-outside-init
|
# pylint: disable=attribute-defined-outside-init
|
||||||
if isinstance(y, (da.Array)):
|
if isinstance(y, (da.Array)):
|
||||||
@ -1525,11 +1508,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
else:
|
else:
|
||||||
params["objective"] = "binary:logistic"
|
params["objective"] = "binary:logistic"
|
||||||
|
|
||||||
evals = await _evaluation_matrices(self.client, eval_set,
|
|
||||||
sample_weight_eval_set,
|
|
||||||
None,
|
|
||||||
self.missing)
|
|
||||||
|
|
||||||
if callable(self.objective):
|
if callable(self.objective):
|
||||||
obj = _objective_decorator(self.objective)
|
obj = _objective_decorator(self.objective)
|
||||||
else:
|
else:
|
||||||
@ -1561,6 +1539,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
self.evals_result_ = results['history']
|
self.evals_result_ = results['history']
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
@ -1575,25 +1554,13 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
xgb_model: Optional[Union[Booster, XGBModel]] = None,
|
||||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||||
|
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||||
feature_weights: Optional[_DaskCollection] = None,
|
feature_weights: Optional[_DaskCollection] = None,
|
||||||
callbacks: Optional[List[TrainingCallback]] = None
|
callbacks: Optional[List[TrainingCallback]] = None
|
||||||
) -> "DaskXGBClassifier":
|
) -> "DaskXGBClassifier":
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
return self.client.sync(
|
args = {k: v for k, v in locals().items() if k != 'self'}
|
||||||
self._fit_async,
|
return self.client.sync(self._fit_async, **args)
|
||||||
X=X,
|
|
||||||
y=y,
|
|
||||||
sample_weight=sample_weight,
|
|
||||||
base_margin=base_margin,
|
|
||||||
eval_set=eval_set,
|
|
||||||
eval_metric=eval_metric,
|
|
||||||
sample_weight_eval_set=sample_weight_eval_set,
|
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
|
||||||
verbose=verbose,
|
|
||||||
xgb_model=xgb_model,
|
|
||||||
feature_weights=feature_weights,
|
|
||||||
callbacks=callbacks,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _predict_proba_async(
|
async def _predict_proba_async(
|
||||||
self,
|
self,
|
||||||
@ -1613,7 +1580,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
output_margin=output_margin)
|
output_margin=output_margin)
|
||||||
return _cls_predict_proba(self.objective, pred_probs, da.vstack)
|
return _cls_predict_proba(self.objective, pred_probs, da.vstack)
|
||||||
|
|
||||||
# pylint: disable=missing-docstring
|
# pylint: disable=missing-function-docstring
|
||||||
def predict_proba(
|
def predict_proba(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DaskCollection,
|
||||||
@ -1632,6 +1599,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
output_margin=output_margin,
|
output_margin=output_margin,
|
||||||
base_margin=base_margin
|
base_margin=base_margin
|
||||||
)
|
)
|
||||||
|
predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__
|
||||||
|
|
||||||
async def _predict_async(
|
async def _predict_async(
|
||||||
self, data: _DaskCollection,
|
self, data: _DaskCollection,
|
||||||
@ -1673,11 +1641,14 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
|||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DaskCollection,
|
||||||
y: _DaskCollection,
|
y: _DaskCollection,
|
||||||
|
group: Optional[_DaskCollection],
|
||||||
qid: Optional[_DaskCollection],
|
qid: Optional[_DaskCollection],
|
||||||
sample_weight: Optional[_DaskCollection],
|
sample_weight: Optional[_DaskCollection],
|
||||||
base_margin: Optional[_DaskCollection],
|
base_margin: Optional[_DaskCollection],
|
||||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
|
||||||
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
sample_weight_eval_set: Optional[List[_DaskCollection]],
|
||||||
|
base_margin_eval_set: Optional[List[_DaskCollection]],
|
||||||
|
eval_group: Optional[List[_DaskCollection]],
|
||||||
eval_qid: Optional[List[_DaskCollection]],
|
eval_qid: Optional[List[_DaskCollection]],
|
||||||
eval_metric: Optional[Union[str, List[str], Metric]],
|
eval_metric: Optional[Union[str, List[str], Metric]],
|
||||||
early_stopping_rounds: int,
|
early_stopping_rounds: int,
|
||||||
@ -1686,22 +1657,26 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
|||||||
feature_weights: Optional[_DaskCollection],
|
feature_weights: Optional[_DaskCollection],
|
||||||
callbacks: Optional[List[TrainingCallback]],
|
callbacks: Optional[List[TrainingCallback]],
|
||||||
) -> "DaskXGBRanker":
|
) -> "DaskXGBRanker":
|
||||||
dtrain = await DaskDMatrix(
|
msg = "Use `qid` instead of `group` on dask interface."
|
||||||
client=self.client,
|
if not (group is None and eval_group is None):
|
||||||
data=X,
|
raise ValueError(msg)
|
||||||
label=y,
|
if qid is None:
|
||||||
|
raise ValueError("`qid` is required for ranking.")
|
||||||
|
params = self.get_xgb_params()
|
||||||
|
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||||
|
self.client,
|
||||||
|
X=X,
|
||||||
|
y=y,
|
||||||
|
group=None,
|
||||||
qid=qid,
|
qid=qid,
|
||||||
weight=sample_weight,
|
sample_weight=sample_weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
feature_weights=feature_weights,
|
feature_weights=feature_weights,
|
||||||
missing=self.missing,
|
eval_set=eval_set,
|
||||||
)
|
sample_weight_eval_set=sample_weight_eval_set,
|
||||||
params = self.get_xgb_params()
|
base_margin_eval_set=base_margin_eval_set,
|
||||||
evals = await _evaluation_matrices(
|
eval_group=None,
|
||||||
self.client,
|
eval_qid=eval_qid,
|
||||||
eval_set,
|
|
||||||
sample_weight_eval_set,
|
|
||||||
sample_qid=eval_qid,
|
|
||||||
missing=self.missing,
|
missing=self.missing,
|
||||||
)
|
)
|
||||||
if eval_metric is not None:
|
if eval_metric is not None:
|
||||||
@ -1728,8 +1703,9 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
|||||||
self.evals_result_ = results["history"]
|
self.evals_result_ = results["history"]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
# pylint: disable=unused-argument, arguments-differ
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit( # pylint: disable=arguments-differ
|
def fit(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DaskCollection,
|
||||||
y: _DaskCollection,
|
y: _DaskCollection,
|
||||||
@ -1739,39 +1715,20 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
|||||||
sample_weight: Optional[_DaskCollection] = None,
|
sample_weight: Optional[_DaskCollection] = None,
|
||||||
base_margin: Optional[_DaskCollection] = None,
|
base_margin: Optional[_DaskCollection] = None,
|
||||||
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
|
||||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
|
||||||
eval_group: Optional[List[_DaskCollection]] = None,
|
eval_group: Optional[List[_DaskCollection]] = None,
|
||||||
eval_qid: Optional[List[_DaskCollection]] = None,
|
eval_qid: Optional[List[_DaskCollection]] = None,
|
||||||
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
eval_metric: Optional[Union[str, List[str], Metric]] = None,
|
||||||
early_stopping_rounds: int = None,
|
early_stopping_rounds: int = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
xgb_model: Optional[Union[XGBModel, Booster]] = None,
|
xgb_model: Optional[Union[XGBModel, Booster]] = None,
|
||||||
|
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||||
|
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||||
feature_weights: Optional[_DaskCollection] = None,
|
feature_weights: Optional[_DaskCollection] = None,
|
||||||
callbacks: Optional[List[TrainingCallback]] = None
|
callbacks: Optional[List[TrainingCallback]] = None
|
||||||
) -> "DaskXGBRanker":
|
) -> "DaskXGBRanker":
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
msg = "Use `qid` instead of `group` on dask interface."
|
args = {k: v for k, v in locals().items() if k != 'self'}
|
||||||
if not (group is None and eval_group is None):
|
return self.client.sync(self._fit_async, **args)
|
||||||
raise ValueError(msg)
|
|
||||||
if qid is None:
|
|
||||||
raise ValueError("`qid` is required for ranking.")
|
|
||||||
return self.client.sync(
|
|
||||||
self._fit_async,
|
|
||||||
X=X,
|
|
||||||
y=y,
|
|
||||||
qid=qid,
|
|
||||||
sample_weight=sample_weight,
|
|
||||||
base_margin=base_margin,
|
|
||||||
eval_set=eval_set,
|
|
||||||
sample_weight_eval_set=sample_weight_eval_set,
|
|
||||||
eval_qid=eval_qid,
|
|
||||||
eval_metric=eval_metric,
|
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
|
||||||
verbose=verbose,
|
|
||||||
xgb_model=xgb_model,
|
|
||||||
feature_weights=feature_weights,
|
|
||||||
callbacks=callbacks,
|
|
||||||
)
|
|
||||||
|
|
||||||
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
|
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
|
||||||
fit.__doc__ = XGBRanker.fit.__doc__
|
fit.__doc__ = XGBRanker.fit.__doc__
|
||||||
|
|||||||
@ -212,6 +212,105 @@ Parameters
|
|||||||
return adddoc
|
return adddoc
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_evaluation_matrices(
|
||||||
|
missing: float,
|
||||||
|
X: Any,
|
||||||
|
y: Any,
|
||||||
|
group: Optional[Any],
|
||||||
|
qid: Optional[Any],
|
||||||
|
sample_weight: Optional[Any],
|
||||||
|
base_margin: Optional[Any],
|
||||||
|
feature_weights: Optional[Any],
|
||||||
|
eval_set: Optional[List[Tuple[Any, Any]]],
|
||||||
|
sample_weight_eval_set: Optional[List[Any]],
|
||||||
|
base_margin_eval_set: Optional[List[Any]],
|
||||||
|
eval_group: Optional[List[Any]],
|
||||||
|
eval_qid: Optional[List[Any]],
|
||||||
|
create_dmatrix: Callable,
|
||||||
|
label_transform: Callable = lambda x: x,
|
||||||
|
) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
|
||||||
|
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way.
|
||||||
|
|
||||||
|
"""
|
||||||
|
train_dmatrix = create_dmatrix(
|
||||||
|
data=X,
|
||||||
|
label=label_transform(y),
|
||||||
|
group=group,
|
||||||
|
qid=qid,
|
||||||
|
weight=sample_weight,
|
||||||
|
base_margin=base_margin,
|
||||||
|
feature_weights=feature_weights,
|
||||||
|
missing=missing,
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate_or_none(meta: Optional[List], name: str) -> List:
|
||||||
|
if meta is None:
|
||||||
|
return [None] * len(eval_set)
|
||||||
|
if len(meta) != len(eval_set):
|
||||||
|
raise ValueError(
|
||||||
|
f"{name}'s length does not eqaul to `eval_set`, " +
|
||||||
|
f"expecting {len(eval_set)}, got {len(meta)}"
|
||||||
|
)
|
||||||
|
return meta
|
||||||
|
|
||||||
|
if eval_set is not None:
|
||||||
|
sample_weight_eval_set = validate_or_none(
|
||||||
|
sample_weight_eval_set, "sample_weight_eval_set"
|
||||||
|
)
|
||||||
|
base_margin_eval_set = validate_or_none(
|
||||||
|
base_margin_eval_set, "base_margin_eval_set"
|
||||||
|
)
|
||||||
|
eval_group = validate_or_none(eval_group, "eval_group")
|
||||||
|
eval_qid = validate_or_none(eval_qid, "eval_qid")
|
||||||
|
|
||||||
|
evals = []
|
||||||
|
for i, (valid_X, valid_y) in enumerate(eval_set):
|
||||||
|
# Skip the duplicated entry.
|
||||||
|
if all(
|
||||||
|
(
|
||||||
|
valid_X is X, valid_y is y,
|
||||||
|
sample_weight_eval_set[i] is sample_weight,
|
||||||
|
base_margin_eval_set[i] is base_margin,
|
||||||
|
eval_group[i] is group,
|
||||||
|
eval_qid[i] is qid
|
||||||
|
)
|
||||||
|
):
|
||||||
|
evals.append(train_dmatrix)
|
||||||
|
else:
|
||||||
|
m = create_dmatrix(
|
||||||
|
data=valid_X,
|
||||||
|
label=label_transform(valid_y),
|
||||||
|
weight=sample_weight_eval_set[i],
|
||||||
|
group=eval_group[i],
|
||||||
|
qid=eval_qid[i],
|
||||||
|
base_margin=base_margin_eval_set[i],
|
||||||
|
missing=missing,
|
||||||
|
)
|
||||||
|
evals.append(m)
|
||||||
|
nevals = len(evals)
|
||||||
|
eval_names = ["validation_{}".format(i) for i in range(nevals)]
|
||||||
|
evals = list(zip(evals, eval_names))
|
||||||
|
else:
|
||||||
|
if any(
|
||||||
|
[
|
||||||
|
meta is not None
|
||||||
|
for meta in [
|
||||||
|
sample_weight_eval_set,
|
||||||
|
base_margin_eval_set,
|
||||||
|
eval_group,
|
||||||
|
eval_qid,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"`eval_set` is not set but one of the other evaluation meta info is "
|
||||||
|
"not None."
|
||||||
|
)
|
||||||
|
evals = []
|
||||||
|
|
||||||
|
return train_dmatrix, evals
|
||||||
|
|
||||||
|
|
||||||
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
@xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
|
||||||
['estimators', 'model', 'objective'])
|
['estimators', 'model', 'objective'])
|
||||||
class XGBModel(XGBModelBase):
|
class XGBModel(XGBModelBase):
|
||||||
@ -281,69 +380,6 @@ class XGBModel(XGBModelBase):
|
|||||||
self.gpu_id = gpu_id
|
self.gpu_id = gpu_id
|
||||||
self.validate_parameters = validate_parameters
|
self.validate_parameters = validate_parameters
|
||||||
|
|
||||||
def _wrap_evaluation_matrices(
|
|
||||||
self, X, y,
|
|
||||||
group,
|
|
||||||
qid,
|
|
||||||
sample_weight,
|
|
||||||
base_margin,
|
|
||||||
feature_weights,
|
|
||||||
eval_set,
|
|
||||||
sample_weight_eval_set,
|
|
||||||
eval_group,
|
|
||||||
eval_qid,
|
|
||||||
label_transform=lambda x: x
|
|
||||||
) -> Tuple[DMatrix, Optional[List[Tuple[DMatrix, str]]]]:
|
|
||||||
'''Convert array_like evaluation matrices into DMatrix, group and qid are only used for
|
|
||||||
valiation but not set in this function
|
|
||||||
|
|
||||||
'''
|
|
||||||
if sample_weight_eval_set is not None:
|
|
||||||
assert eval_set is not None
|
|
||||||
assert len(sample_weight_eval_set) == len(eval_set)
|
|
||||||
if eval_group is not None:
|
|
||||||
assert eval_set is not None
|
|
||||||
assert len(eval_group) == len(eval_set)
|
|
||||||
|
|
||||||
y = label_transform(y)
|
|
||||||
train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
|
|
||||||
base_margin=base_margin,
|
|
||||||
missing=self.missing, nthread=self.n_jobs)
|
|
||||||
train_dmatrix.set_info(feature_weights=feature_weights)
|
|
||||||
|
|
||||||
if eval_set is not None:
|
|
||||||
if sample_weight_eval_set is None:
|
|
||||||
sample_weight_eval_set = [None] * len(eval_set)
|
|
||||||
if eval_group is None:
|
|
||||||
eval_group = [None] * len(eval_set)
|
|
||||||
if eval_qid is None:
|
|
||||||
eval_qid = [None] * len(eval_set)
|
|
||||||
|
|
||||||
evals = []
|
|
||||||
for i, (valid_X, valid_y) in enumerate(eval_set):
|
|
||||||
# Skip the duplicated entry.
|
|
||||||
if (
|
|
||||||
valid_X is X and
|
|
||||||
valid_y is y and
|
|
||||||
sample_weight_eval_set[i] is sample_weight and
|
|
||||||
eval_group[i] is group and
|
|
||||||
eval_qid[i] is qid
|
|
||||||
):
|
|
||||||
evals.append(train_dmatrix)
|
|
||||||
else:
|
|
||||||
m = DMatrix(valid_X,
|
|
||||||
label=label_transform(valid_y),
|
|
||||||
missing=self.missing, weight=sample_weight_eval_set[i],
|
|
||||||
nthread=self.n_jobs)
|
|
||||||
evals.append(m)
|
|
||||||
|
|
||||||
nevals = len(evals)
|
|
||||||
eval_names = ["validation_{}".format(i) for i in range(nevals)]
|
|
||||||
evals = list(zip(evals, eval_names))
|
|
||||||
else:
|
|
||||||
evals = ()
|
|
||||||
return train_dmatrix, evals
|
|
||||||
|
|
||||||
def _more_tags(self):
|
def _more_tags(self):
|
||||||
'''Tags used for scikit-learn data validation.'''
|
'''Tags used for scikit-learn data validation.'''
|
||||||
return {'allow_nan': True, 'no_validation': True}
|
return {'allow_nan': True, 'no_validation': True}
|
||||||
@ -580,17 +616,29 @@ class XGBModel(XGBModelBase):
|
|||||||
self.evals_result_ = evals_result
|
self.evals_result_ = evals_result
|
||||||
|
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(self, X, y, *, sample_weight=None, base_margin=None,
|
def fit(
|
||||||
eval_set=None, eval_metric=None, early_stopping_rounds=None,
|
self,
|
||||||
verbose=True, xgb_model=None, sample_weight_eval_set=None,
|
X,
|
||||||
|
y,
|
||||||
|
*,
|
||||||
|
sample_weight=None,
|
||||||
|
base_margin=None,
|
||||||
|
eval_set=None,
|
||||||
|
eval_metric=None,
|
||||||
|
early_stopping_rounds=None,
|
||||||
|
verbose=True,
|
||||||
|
xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None,
|
||||||
|
sample_weight_eval_set=None,
|
||||||
|
base_margin_eval_set=None,
|
||||||
feature_weights=None,
|
feature_weights=None,
|
||||||
callbacks=None):
|
callbacks=None
|
||||||
|
):
|
||||||
# pylint: disable=invalid-name,attribute-defined-outside-init
|
# pylint: disable=invalid-name,attribute-defined-outside-init
|
||||||
"""Fit gradient boosting model.
|
"""Fit gradient boosting model.
|
||||||
|
|
||||||
Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
|
Note that calling ``fit()`` multiple times will cause the model object to be
|
||||||
scratch. To resume training from a previous checkpoint, explicitly pass ``xgb_model``
|
re-fit from scratch. To resume training from a previous checkpoint, explicitly
|
||||||
argument.
|
pass ``xgb_model`` argument.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -611,12 +659,11 @@ class XGBModel(XGBModelBase):
|
|||||||
doc/parameter.rst.
|
doc/parameter.rst.
|
||||||
If a list of str, should be the list of multiple built-in evaluation metrics
|
If a list of str, should be the list of multiple built-in evaluation metrics
|
||||||
to use.
|
to use.
|
||||||
If callable, a custom evaluation metric. The call
|
If callable, a custom evaluation metric. The call signature is
|
||||||
signature is ``func(y_predicted, y_true)`` where ``y_true`` will be a
|
``func(y_predicted, y_true)`` where ``y_true`` will be a DMatrix object such
|
||||||
DMatrix object such that you may need to call the ``get_label``
|
that you may need to call the ``get_label`` method. It must return a str,
|
||||||
method. It must return a str, value pair where the str is a name
|
value pair where the str is a name for the evaluation and value is the value
|
||||||
for the evaluation and value is the value of the evaluation
|
of the evaluation function. The callable custom objective is always minimized.
|
||||||
function. The callable custom objective is always minimized.
|
|
||||||
early_stopping_rounds : int
|
early_stopping_rounds : int
|
||||||
Activates early stopping. Validation metric needs to improve at least once in
|
Activates early stopping. Validation metric needs to improve at least once in
|
||||||
every **early_stopping_rounds** round(s) to continue training.
|
every **early_stopping_rounds** round(s) to continue training.
|
||||||
@ -629,14 +676,17 @@ class XGBModel(XGBModelBase):
|
|||||||
If early stopping occurs, the model will have three additional fields:
|
If early stopping occurs, the model will have three additional fields:
|
||||||
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
|
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
|
||||||
verbose : bool
|
verbose : bool
|
||||||
If `verbose` and an evaluation set is used, writes the evaluation
|
If `verbose` and an evaluation set is used, writes the evaluation metric
|
||||||
metric measured on the validation set to stderr.
|
measured on the validation set to stderr.
|
||||||
xgb_model : Union[str, Booster, XGBModel]
|
xgb_model :
|
||||||
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
|
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
|
||||||
loaded before training (allows training continuation).
|
loaded before training (allows training continuation).
|
||||||
sample_weight_eval_set : list, optional
|
sample_weight_eval_set : list, optional
|
||||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
A list of the form [L_1, L_2, ..., L_n], where each L_i is an array like
|
||||||
instance weights on the i-th validation set.
|
object storing instance weights for the i-th validation set.
|
||||||
|
base_margin_eval_set : list, optional
|
||||||
|
A list of the form [M_1, M_2, ..., M_n], where each M_i is an array like
|
||||||
|
object storing base margin for the i-th validation set.
|
||||||
feature_weights: array_like
|
feature_weights: array_like
|
||||||
Weight for each feature, defines the probability of each feature being
|
Weight for each feature, defines the probability of each feature being
|
||||||
selected when colsample is being used. All values must be greater than 0,
|
selected when colsample is being used. All values must be greater than 0,
|
||||||
@ -655,12 +705,21 @@ class XGBModel(XGBModelBase):
|
|||||||
"""
|
"""
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
|
|
||||||
train_dmatrix, evals = self._wrap_evaluation_matrices(
|
train_dmatrix, evals = _wrap_evaluation_matrices(
|
||||||
X, y, group=None, qid=None, sample_weight=sample_weight,
|
missing=self.missing,
|
||||||
|
X=X,
|
||||||
|
y=y,
|
||||||
|
group=None,
|
||||||
|
qid=None,
|
||||||
|
sample_weight=sample_weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
feature_weights=feature_weights, eval_set=eval_set,
|
feature_weights=feature_weights,
|
||||||
|
eval_set=eval_set,
|
||||||
sample_weight_eval_set=sample_weight_eval_set,
|
sample_weight_eval_set=sample_weight_eval_set,
|
||||||
eval_group=None, eval_qid=None
|
base_margin_eval_set=base_margin_eval_set,
|
||||||
|
eval_group=None,
|
||||||
|
eval_qid=None,
|
||||||
|
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
|
||||||
)
|
)
|
||||||
params = self.get_xgb_params()
|
params = self.get_xgb_params()
|
||||||
|
|
||||||
@ -671,13 +730,19 @@ class XGBModel(XGBModelBase):
|
|||||||
obj = None
|
obj = None
|
||||||
|
|
||||||
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
||||||
self._Booster = train(params, train_dmatrix,
|
self._Booster = train(
|
||||||
self.get_num_boosting_rounds(), evals=evals,
|
params,
|
||||||
|
train_dmatrix,
|
||||||
|
self.get_num_boosting_rounds(),
|
||||||
|
evals=evals,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=evals_result,
|
evals_result=evals_result,
|
||||||
obj=obj, feval=feval,
|
obj=obj,
|
||||||
verbose_eval=verbose, xgb_model=model,
|
feval=feval,
|
||||||
callbacks=callbacks)
|
verbose_eval=verbose,
|
||||||
|
xgb_model=model,
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
self._set_evaluation_result(evals_result)
|
self._set_evaluation_result(evals_result)
|
||||||
return self
|
return self
|
||||||
@ -911,7 +976,9 @@ class XGBModel(XGBModelBase):
|
|||||||
return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])
|
return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])
|
||||||
|
|
||||||
|
|
||||||
def _cls_predict_proba(objective: Union[str, Callable], prediction: Any, vstack: Callable) -> Any:
|
def _cls_predict_proba(
|
||||||
|
objective: Union[str, Callable], prediction: Any, vstack: Callable
|
||||||
|
) -> Any:
|
||||||
if objective == 'multi:softmax':
|
if objective == 'multi:softmax':
|
||||||
raise ValueError('multi:softmax objective does not support predict_proba,'
|
raise ValueError('multi:softmax objective does not support predict_proba,'
|
||||||
' use `multi:softprob` or `binary:logistic` instead.')
|
' use `multi:softprob` or `binary:logistic` instead.')
|
||||||
@ -931,8 +998,8 @@ def _cls_predict_proba(objective: Union[str, Callable], prediction: Any, vstack:
|
|||||||
n_estimators : int
|
n_estimators : int
|
||||||
Number of boosting rounds.
|
Number of boosting rounds.
|
||||||
use_label_encoder : bool
|
use_label_encoder : bool
|
||||||
(Deprecated) Use the label encoder from scikit-learn to encode the labels. For new code,
|
(Deprecated) Use the label encoder from scikit-learn to encode the labels. For new
|
||||||
we recommend that you set this parameter to False.
|
code, we recommend that you set this parameter to False.
|
||||||
''')
|
''')
|
||||||
class XGBClassifier(XGBModel, XGBClassifierBase):
|
class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||||
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
|
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
|
||||||
@ -955,45 +1022,55 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
verbose=True,
|
verbose=True,
|
||||||
xgb_model=None,
|
xgb_model=None,
|
||||||
sample_weight_eval_set=None,
|
sample_weight_eval_set=None,
|
||||||
|
base_margin_eval_set=None,
|
||||||
feature_weights=None,
|
feature_weights=None,
|
||||||
callbacks=None
|
callbacks=None
|
||||||
):
|
):
|
||||||
# pylint: disable = attribute-defined-outside-init,arguments-differ,too-many-statements
|
# pylint: disable = attribute-defined-outside-init,too-many-statements
|
||||||
|
|
||||||
can_use_label_encoder = True
|
can_use_label_encoder = True
|
||||||
label_encoding_check_error = (
|
label_encoding_check_error = (
|
||||||
'The label must consist of integer labels of form 0, 1, 2, ..., [num_class - 1].')
|
"The label must consist of integer "
|
||||||
|
"labels of form 0, 1, 2, ..., [num_class - 1]."
|
||||||
|
)
|
||||||
label_encoder_deprecation_msg = (
|
label_encoder_deprecation_msg = (
|
||||||
'The use of label encoder in XGBClassifier is deprecated and will be ' +
|
"The use of label encoder in XGBClassifier is deprecated and will be "
|
||||||
'removed in a future release. To remove this warning, do the ' +
|
"removed in a future release. To remove this warning, do the "
|
||||||
'following: 1) Pass option use_label_encoder=False when constructing ' +
|
"following: 1) Pass option use_label_encoder=False when constructing "
|
||||||
'XGBClassifier object; and 2) Encode your labels (y) as integers ' +
|
"XGBClassifier object; and 2) Encode your labels (y) as integers "
|
||||||
'starting with 0, i.e. 0, 1, 2, ..., [num_class - 1].')
|
"starting with 0, i.e. 0, 1, 2, ..., [num_class - 1]."
|
||||||
|
)
|
||||||
|
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
if _is_cudf_df(y) or _is_cudf_ser(y):
|
if _is_cudf_df(y) or _is_cudf_ser(y):
|
||||||
import cupy as cp # pylint: disable=E0401
|
import cupy as cp # pylint: disable=E0401
|
||||||
|
|
||||||
self.classes_ = cp.unique(y.values)
|
self.classes_ = cp.unique(y.values)
|
||||||
self.n_classes_ = len(self.classes_)
|
self.n_classes_ = len(self.classes_)
|
||||||
can_use_label_encoder = False
|
can_use_label_encoder = False
|
||||||
expected_classes = cp.arange(self.n_classes_)
|
expected_classes = cp.arange(self.n_classes_)
|
||||||
if (self.classes_.shape != expected_classes.shape or
|
if (
|
||||||
not (self.classes_ == expected_classes).all()):
|
self.classes_.shape != expected_classes.shape
|
||||||
|
or not (self.classes_ == expected_classes).all()
|
||||||
|
):
|
||||||
raise ValueError(label_encoding_check_error)
|
raise ValueError(label_encoding_check_error)
|
||||||
elif _is_cupy_array(y):
|
elif _is_cupy_array(y):
|
||||||
import cupy as cp # pylint: disable=E0401
|
import cupy as cp # pylint: disable=E0401
|
||||||
|
|
||||||
self.classes_ = cp.unique(y)
|
self.classes_ = cp.unique(y)
|
||||||
self.n_classes_ = len(self.classes_)
|
self.n_classes_ = len(self.classes_)
|
||||||
can_use_label_encoder = False
|
can_use_label_encoder = False
|
||||||
expected_classes = cp.arange(self.n_classes_)
|
expected_classes = cp.arange(self.n_classes_)
|
||||||
if (self.classes_.shape != expected_classes.shape or
|
if (
|
||||||
not (self.classes_ == expected_classes).all()):
|
self.classes_.shape != expected_classes.shape
|
||||||
|
or not (self.classes_ == expected_classes).all()
|
||||||
|
):
|
||||||
raise ValueError(label_encoding_check_error)
|
raise ValueError(label_encoding_check_error)
|
||||||
else:
|
else:
|
||||||
self.classes_ = np.unique(y)
|
self.classes_ = np.unique(y)
|
||||||
self.n_classes_ = len(self.classes_)
|
self.n_classes_ = len(self.classes_)
|
||||||
if not self.use_label_encoder and (
|
if not self.use_label_encoder and (
|
||||||
not np.array_equal(self.classes_, np.arange(self.n_classes_))):
|
not np.array_equal(self.classes_, np.arange(self.n_classes_))
|
||||||
|
):
|
||||||
raise ValueError(label_encoding_check_error)
|
raise ValueError(label_encoding_check_error)
|
||||||
|
|
||||||
params = self.get_xgb_params()
|
params = self.get_xgb_params()
|
||||||
@ -1008,8 +1085,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
if self.n_classes_ > 2:
|
if self.n_classes_ > 2:
|
||||||
# Switch to using a multiclass objective in the underlying
|
# Switch to using a multiclass objective in the underlying
|
||||||
# XGB instance
|
# XGB instance
|
||||||
params['objective'] = 'multi:softprob'
|
params["objective"] = "multi:softprob"
|
||||||
params['num_class'] = self.n_classes_
|
params["num_class"] = self.n_classes_
|
||||||
|
|
||||||
if self.use_label_encoder:
|
if self.use_label_encoder:
|
||||||
if not can_use_label_encoder:
|
if not can_use_label_encoder:
|
||||||
@ -1021,34 +1098,45 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
self._le = XGBoostLabelEncoder().fit(y)
|
self._le = XGBoostLabelEncoder().fit(y)
|
||||||
label_transform = self._le.transform
|
label_transform = self._le.transform
|
||||||
else:
|
else:
|
||||||
label_transform = (lambda x: x)
|
label_transform = lambda x: x
|
||||||
|
|
||||||
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
||||||
if len(X.shape) != 2:
|
if len(X.shape) != 2:
|
||||||
# Simply raise an error here since there might be many
|
# Simply raise an error here since there might be many
|
||||||
# different ways of reshaping
|
# different ways of reshaping
|
||||||
raise ValueError(
|
raise ValueError("Please reshape the input data X into 2-dimensional matrix.")
|
||||||
'Please reshape the input data X into 2-dimensional matrix.')
|
|
||||||
|
|
||||||
train_dmatrix, evals = self._wrap_evaluation_matrices(
|
train_dmatrix, evals = _wrap_evaluation_matrices(
|
||||||
X, y, group=None, qid=None,
|
missing=self.missing,
|
||||||
|
X=X,
|
||||||
|
y=y,
|
||||||
|
group=None,
|
||||||
|
qid=None,
|
||||||
sample_weight=sample_weight,
|
sample_weight=sample_weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
feature_weights=feature_weights,
|
feature_weights=feature_weights,
|
||||||
eval_set=eval_set,
|
eval_set=eval_set,
|
||||||
sample_weight_eval_set=sample_weight_eval_set,
|
sample_weight_eval_set=sample_weight_eval_set,
|
||||||
|
base_margin_eval_set=base_margin_eval_set,
|
||||||
eval_group=None,
|
eval_group=None,
|
||||||
eval_qid=None,
|
eval_qid=None,
|
||||||
label_transform=label_transform
|
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
|
||||||
|
label_transform=label_transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._Booster = train(params, train_dmatrix,
|
self._Booster = train(
|
||||||
|
params,
|
||||||
|
train_dmatrix,
|
||||||
self.get_num_boosting_rounds(),
|
self.get_num_boosting_rounds(),
|
||||||
evals=evals,
|
evals=evals,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=evals_result, obj=obj, feval=feval,
|
evals_result=evals_result,
|
||||||
verbose_eval=verbose, xgb_model=model,
|
obj=obj,
|
||||||
callbacks=callbacks)
|
feval=feval,
|
||||||
|
verbose_eval=verbose,
|
||||||
|
xgb_model=model,
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
if not callable(self.objective):
|
if not callable(self.objective):
|
||||||
self.objective = params["objective"]
|
self.objective = params["objective"]
|
||||||
@ -1178,8 +1266,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
n_estimators : int
|
n_estimators : int
|
||||||
Number of trees in random forest to fit.
|
Number of trees in random forest to fit.
|
||||||
use_label_encoder : bool
|
use_label_encoder : bool
|
||||||
(Deprecated) Use the label encoder from scikit-learn to encode the labels. For new code,
|
(Deprecated) Use the label encoder from scikit-learn to encode the labels. For new
|
||||||
we recommend that you set this parameter to False.
|
code, we recommend that you set this parameter to False.
|
||||||
''')
|
''')
|
||||||
class XGBRFClassifier(XGBClassifier):
|
class XGBRFClassifier(XGBClassifier):
|
||||||
# pylint: disable=missing-docstring
|
# pylint: disable=missing-docstring
|
||||||
@ -1285,11 +1373,10 @@ class XGBRFRegressor(XGBRegressor):
|
|||||||
class XGBRanker(XGBModel, XGBRankerMixIn):
|
class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||||
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def __init__(self, *, objective='rank:pairwise', **kwargs):
|
def __init__(self, *, objective="rank:pairwise", **kwargs):
|
||||||
super().__init__(objective=objective, **kwargs)
|
super().__init__(objective=objective, **kwargs)
|
||||||
if callable(self.objective):
|
if callable(self.objective):
|
||||||
raise ValueError(
|
raise ValueError("custom objective function not supported by XGBRanker")
|
||||||
"custom objective function not supported by XGBRanker")
|
|
||||||
if "rank:" not in self.objective:
|
if "rank:" not in self.objective:
|
||||||
raise ValueError("please use XGBRanker for ranking task")
|
raise ValueError("please use XGBRanker for ranking task")
|
||||||
|
|
||||||
@ -1304,22 +1391,23 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
sample_weight=None,
|
sample_weight=None,
|
||||||
base_margin=None,
|
base_margin=None,
|
||||||
eval_set=None,
|
eval_set=None,
|
||||||
sample_weight_eval_set=None,
|
|
||||||
eval_group=None,
|
eval_group=None,
|
||||||
eval_qid=None,
|
eval_qid=None,
|
||||||
eval_metric=None,
|
eval_metric=None,
|
||||||
early_stopping_rounds=None,
|
early_stopping_rounds=None,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
xgb_model=None,
|
xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
|
||||||
|
sample_weight_eval_set=None,
|
||||||
|
base_margin_eval_set=None,
|
||||||
feature_weights=None,
|
feature_weights=None,
|
||||||
callbacks=None
|
callbacks=None
|
||||||
) -> "XGBRanker":
|
) -> "XGBRanker":
|
||||||
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
# pylint: disable = attribute-defined-outside-init,arguments-differ
|
||||||
"""Fit gradient boosting ranker
|
"""Fit gradient boosting ranker
|
||||||
|
|
||||||
Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
|
Note that calling ``fit()`` multiple times will cause the model object to be
|
||||||
scratch. To resume training from a previous checkpoint, explicitly pass ``xgb_model``
|
re-fit from scratch. To resume training from a previous checkpoint, explicitly
|
||||||
argument.
|
pass ``xgb_model`` argument.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@ -1343,24 +1431,12 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
data point). This is because we only care about the relative ordering of
|
data point). This is because we only care about the relative ordering of
|
||||||
data points within each group, so it doesn't make sense to assign weights
|
data points within each group, so it doesn't make sense to assign weights
|
||||||
to individual data points.
|
to individual data points.
|
||||||
|
|
||||||
base_margin : array_like
|
base_margin : array_like
|
||||||
Global bias for each instance.
|
Global bias for each instance.
|
||||||
eval_set : list, optional
|
eval_set : list, optional
|
||||||
A list of (X, y) tuple pairs to use as validation sets, for which
|
A list of (X, y) tuple pairs to use as validation sets, for which
|
||||||
metrics will be computed.
|
metrics will be computed.
|
||||||
Validation metrics will help us track the performance of the model.
|
Validation metrics will help us track the performance of the model.
|
||||||
sample_weight_eval_set : list, optional
|
|
||||||
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
|
||||||
group weights on the i-th validation set.
|
|
||||||
|
|
||||||
.. note:: Weights are per-group for ranking tasks
|
|
||||||
|
|
||||||
In ranking task, one weight is assigned to each query group (not each
|
|
||||||
data point). This is because we only care about the relative ordering of
|
|
||||||
data points within each group, so it doesn't make sense to assign
|
|
||||||
weights to individual data points.
|
|
||||||
|
|
||||||
eval_group : list of arrays, optional
|
eval_group : list of arrays, optional
|
||||||
A list in which ``eval_group[i]`` is the list containing the sizes of all
|
A list in which ``eval_group[i]`` is the list containing the sizes of all
|
||||||
query groups in the ``i``-th pair in **eval_set**.
|
query groups in the ``i``-th pair in **eval_set**.
|
||||||
@ -1374,22 +1450,34 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
to use. The custom evaluation metric is not yet supported for the ranker.
|
to use. The custom evaluation metric is not yet supported for the ranker.
|
||||||
early_stopping_rounds : int
|
early_stopping_rounds : int
|
||||||
Activates early stopping. Validation metric needs to improve at least once in
|
Activates early stopping. Validation metric needs to improve at least once in
|
||||||
every **early_stopping_rounds** round(s) to continue training.
|
every **early_stopping_rounds** round(s) to continue training. Requires at
|
||||||
Requires at least one item in **eval_set**.
|
least one item in **eval_set**.
|
||||||
The method returns the model from the last iteration (not the best one).
|
The method returns the model from the last iteration (not the best one). If
|
||||||
If there's more than one item in **eval_set**, the last entry will be used
|
there's more than one item in **eval_set**, the last entry will be used for
|
||||||
for early stopping.
|
early stopping.
|
||||||
If there's more than one metric in **eval_metric**, the last metric
|
If there's more than one metric in **eval_metric**, the last metric will be
|
||||||
will be used for early stopping.
|
used for early stopping.
|
||||||
If early stopping occurs, the model will have three additional
|
If early stopping occurs, the model will have three additional fields:
|
||||||
fields: ``clf.best_score``, ``clf.best_iteration`` and
|
``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
|
||||||
``clf.best_ntree_limit``.
|
|
||||||
verbose : bool
|
verbose : bool
|
||||||
If `verbose` and an evaluation set is used, writes the evaluation
|
If `verbose` and an evaluation set is used, writes the evaluation metric
|
||||||
metric measured on the validation set to stderr.
|
measured on the validation set to stderr.
|
||||||
xgb_model : Union[str, Booster, XGBModel]
|
xgb_model :
|
||||||
file name of stored XGBoost model or 'Booster' instance XGBoost
|
file name of stored XGBoost model or 'Booster' instance XGBoost model to be
|
||||||
model to be loaded before training (allows training continuation).
|
loaded before training (allows training continuation).
|
||||||
|
sample_weight_eval_set : list, optional
|
||||||
|
A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
|
||||||
|
group weights on the i-th validation set.
|
||||||
|
|
||||||
|
.. note:: Weights are per-group for ranking tasks
|
||||||
|
|
||||||
|
In ranking task, one weight is assigned to each query group (not each
|
||||||
|
data point). This is because we only care about the relative ordering of
|
||||||
|
data points within each group, so it doesn't make sense to assign
|
||||||
|
weights to individual data points.
|
||||||
|
base_margin_eval_set : list, optional
|
||||||
|
A list of the form [M_1, M_2, ..., M_n], where each M_i is an array like
|
||||||
|
object storing base margin for the i-th validation set.
|
||||||
feature_weights: array_like
|
feature_weights: array_like
|
||||||
Weight for each feature, defines the probability of each feature being
|
Weight for each feature, defines the probability of each feature being
|
||||||
selected when colsample is being used. All values must be greater than 0,
|
selected when colsample is being used. All values must be greater than 0,
|
||||||
@ -1406,6 +1494,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
save_best=True)]
|
save_best=True)]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
# check if group information is provided
|
||||||
if group is None and qid is None:
|
if group is None and qid is None:
|
||||||
raise ValueError("group or qid is required for ranking task")
|
raise ValueError("group or qid is required for ranking task")
|
||||||
|
|
||||||
@ -1413,24 +1502,10 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
if eval_group is None and eval_qid is None:
|
if eval_group is None and eval_qid is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"eval_group or eval_qid is required if eval_set is not None")
|
"eval_group or eval_qid is required if eval_set is not None")
|
||||||
if (
|
train_dmatrix, evals = _wrap_evaluation_matrices(
|
||||||
(eval_group is not None and len(eval_group) != len(eval_set)) and
|
missing=self.missing,
|
||||||
(eval_qid is not None and len(eval_qid) != len(eval_set))
|
X=X,
|
||||||
):
|
y=y,
|
||||||
raise ValueError(
|
|
||||||
"length of eval_group or eval_qid should match that of eval_set"
|
|
||||||
)
|
|
||||||
for i in range(len(eval_set)):
|
|
||||||
if (
|
|
||||||
(eval_group is not None and eval_group[i] is not None) and
|
|
||||||
(eval_qid is not None and eval_qid[i] is not None)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Only one of the qid or group should be provided."
|
|
||||||
)
|
|
||||||
|
|
||||||
train_dmatrix, evals = self._wrap_evaluation_matrices(
|
|
||||||
X, y,
|
|
||||||
group=group,
|
group=group,
|
||||||
qid=qid,
|
qid=qid,
|
||||||
sample_weight=sample_weight,
|
sample_weight=sample_weight,
|
||||||
@ -1438,54 +1513,32 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
feature_weights=feature_weights,
|
feature_weights=feature_weights,
|
||||||
eval_set=eval_set,
|
eval_set=eval_set,
|
||||||
sample_weight_eval_set=sample_weight_eval_set,
|
sample_weight_eval_set=sample_weight_eval_set,
|
||||||
|
base_margin_eval_set=base_margin_eval_set,
|
||||||
eval_group=eval_group,
|
eval_group=eval_group,
|
||||||
eval_qid=eval_qid
|
eval_qid=eval_qid,
|
||||||
)
|
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
|
||||||
|
|
||||||
if qid is not None:
|
|
||||||
train_dmatrix.set_info(qid=qid)
|
|
||||||
elif group is not None:
|
|
||||||
train_dmatrix.set_info(group=group)
|
|
||||||
else:
|
|
||||||
raise ValueError("Either qid or group should be provided for ranking task.")
|
|
||||||
|
|
||||||
if evals is not None:
|
|
||||||
for i, e in enumerate(evals):
|
|
||||||
if eval_qid is not None and eval_qid[i] is not None:
|
|
||||||
assert eval_group is None or eval_group[i] is None, (
|
|
||||||
'Only one of the eval_qid or eval_group for each evaluation '
|
|
||||||
'dataset should be provided.'
|
|
||||||
)
|
|
||||||
e[0].set_info(qid=qid)
|
|
||||||
elif eval_group is not None and eval_group[i] is not None:
|
|
||||||
e[0].set_info(group=eval_group[i])
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
'Either eval_qid or eval_group for each evaluation dataset should'
|
|
||||||
' be provided for ranking task.'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
params = self.get_xgb_params()
|
params = self.get_xgb_params()
|
||||||
|
|
||||||
feval = eval_metric if callable(eval_metric) else None
|
model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
|
||||||
if eval_metric is not None:
|
if callable(feval):
|
||||||
if callable(eval_metric):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Custom evaluation metric is not yet supported for XGBRanker.')
|
'Custom evaluation metric is not yet supported for XGBRanker.'
|
||||||
params.update({'eval_metric': eval_metric})
|
)
|
||||||
if hasattr(xgb_model, '_Booster'):
|
|
||||||
# Handle the case when xgb_model is a sklearn model object
|
|
||||||
xgb_model = xgb_model._Booster # pylint: disable=protected-access
|
|
||||||
|
|
||||||
self._Booster = train(params, train_dmatrix,
|
self._Booster = train(
|
||||||
self.get_num_boosting_rounds(),
|
params, train_dmatrix,
|
||||||
|
self.n_estimators,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals=evals,
|
evals=evals,
|
||||||
evals_result=evals_result, feval=feval,
|
evals_result=evals_result, feval=feval,
|
||||||
verbose_eval=verbose, xgb_model=xgb_model,
|
verbose_eval=verbose, xgb_model=model,
|
||||||
callbacks=callbacks)
|
callbacks=callbacks
|
||||||
|
)
|
||||||
|
|
||||||
self.objective = params["objective"]
|
self.objective = params["objective"]
|
||||||
|
|
||||||
self._set_evaluation_result(evals_result)
|
self._set_evaluation_result(evals_result)
|
||||||
return self
|
return self
|
||||||
|
|||||||
@ -314,6 +314,14 @@ class TestDistributedGPU:
|
|||||||
for i in range(len(ddqdm_names)):
|
for i in range(len(ddqdm_names)):
|
||||||
assert ddqdm_names[i] == dqdm_names[i]
|
assert ddqdm_names[i] == dqdm_names[i]
|
||||||
|
|
||||||
|
sig = OrderedDict(signature(xgb.XGBRanker.fit).parameters)
|
||||||
|
ranker_names = list(sig.keys())
|
||||||
|
sig = OrderedDict(signature(xgb.dask.DaskXGBRanker.fit).parameters)
|
||||||
|
dranker_names = list(sig.keys())
|
||||||
|
|
||||||
|
for rn, drn in zip(ranker_names, dranker_names):
|
||||||
|
assert rn == drn
|
||||||
|
|
||||||
def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
|
def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
pytest.skip("Skipping dask tests on Windows")
|
pytest.skip("Skipping dask tests on Windows")
|
||||||
|
|||||||
@ -17,7 +17,7 @@ import subprocess
|
|||||||
import hypothesis
|
import hypothesis
|
||||||
from hypothesis import given, settings, note, HealthCheck
|
from hypothesis import given, settings, note, HealthCheck
|
||||||
from test_updaters import hist_parameter_strategy, exact_parameter_strategy
|
from test_updaters import hist_parameter_strategy, exact_parameter_strategy
|
||||||
from test_with_sklearn import run_feature_weights
|
from test_with_sklearn import run_feature_weights, run_data_initialization
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||||
@ -176,6 +176,22 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
|
|||||||
|
|
||||||
assert np.all(predictions_1.compute() == predictions_2.compute())
|
assert np.all(predictions_1.compute() == predictions_2.compute())
|
||||||
|
|
||||||
|
margined = xgb.dask.DaskXGBClassifier(n_estimators=4)
|
||||||
|
margined.fit(
|
||||||
|
X=X, y=y, base_margin=margin, eval_set=[(X, y)], base_margin_eval_set=[margin]
|
||||||
|
)
|
||||||
|
|
||||||
|
unmargined = xgb.dask.DaskXGBClassifier(n_estimators=4)
|
||||||
|
unmargined.fit(X=X, y=y, eval_set=[(X, y)], base_margin=margin)
|
||||||
|
|
||||||
|
margined_res = margined.evals_result()['validation_0']['logloss']
|
||||||
|
unmargined_res = unmargined.evals_result()['validation_0']['logloss']
|
||||||
|
|
||||||
|
assert len(margined_res) == len(unmargined_res)
|
||||||
|
for i in range(len(margined_res)):
|
||||||
|
# margined is the correct one, so it should have a smaller error.
|
||||||
|
assert margined_res[i] < unmargined_res[i]
|
||||||
|
|
||||||
|
|
||||||
def test_dask_missing_value_reg(client: "Client") -> None:
|
def test_dask_missing_value_reg(client: "Client") -> None:
|
||||||
X_0 = np.ones((20 // 2, kCols))
|
X_0 = np.ones((20 // 2, kCols))
|
||||||
@ -955,7 +971,7 @@ class TestWithDask:
|
|||||||
results_native['validation_0']['rmse'])
|
results_native['validation_0']['rmse'])
|
||||||
tm.non_increasing(results_native['validation_0']['rmse'])
|
tm.non_increasing(results_native['validation_0']['rmse'])
|
||||||
|
|
||||||
def test_data_initialization(self) -> None:
|
def test_no_duplicated_partition(self) -> None:
|
||||||
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
||||||
generate unnecessary copies of data.
|
generate unnecessary copies of data.
|
||||||
|
|
||||||
@ -995,6 +1011,13 @@ class TestWithDask:
|
|||||||
# Subtract the on disk resource from each worker
|
# Subtract the on disk resource from each worker
|
||||||
assert cnt - n_workers == n_partitions
|
assert cnt - n_workers == n_partitions
|
||||||
|
|
||||||
|
def test_data_initialization(self, client: "Client") -> None:
    """Assert that we don't create a duplicated DMatrix for the dask classifier.

    Loads the scikit-learn digits data, converts both features and labels to
    dask collections built with ``chunksize=32``, and hands them to
    ``run_data_initialization`` together with the dask DMatrix and
    classifier types.
    """
    # NOTE(review): ``client`` is not referenced in the body; presumably the
    # fixture is requested so a dask client is active during fit -- confirm.
    from sklearn.datasets import load_digits
    X, y = load_digits(return_X_y=True)
    X, y = dd.from_array(X, chunksize=32), dd.from_array(y, chunksize=32)
    run_data_initialization(xgb.dask.DaskDMatrix, xgb.dask.DaskXGBClassifier, X, y)
|
||||||
|
|
||||||
def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None:
|
def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None:
|
||||||
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
|
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
|
||||||
Xy = xgb.dask.DaskDMatrix(client, X, y)
|
Xy = xgb.dask.DaskDMatrix(client, X, y)
|
||||||
|
|||||||
@ -717,13 +717,13 @@ def test_validation_weights_xgbmodel():
|
|||||||
assert all((logloss_with_weights[i] != logloss_without_weights[i]
|
assert all((logloss_with_weights[i] != logloss_without_weights[i]
|
||||||
for i in [0, 1]))
|
for i in [0, 1]))
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(ValueError):
|
||||||
# length of eval set and sample weight doesn't match.
|
# length of eval set and sample weight doesn't match.
|
||||||
clf.fit(X_train, y_train, sample_weight=weights_train,
|
clf.fit(X_train, y_train, sample_weight=weights_train,
|
||||||
eval_set=[(X_train, y_train), (X_test, y_test)],
|
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||||
sample_weight_eval_set=[weights_train])
|
sample_weight_eval_set=[weights_train])
|
||||||
|
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(ValueError):
|
||||||
cls = xgb.XGBClassifier()
|
cls = xgb.XGBClassifier()
|
||||||
cls.fit(X_train, y_train, sample_weight=weights_train,
|
cls.fit(X_train, y_train, sample_weight=weights_train,
|
||||||
eval_set=[(X_train, y_train), (X_test, y_test)],
|
eval_set=[(X_train, y_train), (X_test, y_test)],
|
||||||
@ -1118,19 +1118,9 @@ def run_boost_from_prediction(tree_method):
|
|||||||
assert np.all(predictions_1 == predictions_2)
|
assert np.all(predictions_1 == predictions_2)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
|
||||||
def test_boost_from_prediction_hist():
|
def test_boost_from_prediction(tree_method):
|
||||||
run_boost_from_prediction('hist')
|
run_boost_from_prediction(tree_method)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
|
||||||
def test_boost_from_prediction_approx():
|
|
||||||
run_boost_from_prediction('approx')
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
|
||||||
def test_boost_from_prediction_exact():
|
|
||||||
run_boost_from_prediction('exact')
|
|
||||||
|
|
||||||
|
|
||||||
def test_estimator_type():
|
def test_estimator_type():
|
||||||
@ -1154,3 +1144,32 @@ def test_estimator_type():
|
|||||||
|
|
||||||
cls = xgb.XGBClassifier()
|
cls = xgb.XGBClassifier()
|
||||||
cls.load_model(path) # no error
|
cls.load_model(path) # no error
|
||||||
|
|
||||||
|
|
||||||
|
def run_data_initialization(DMatrix, model, X, y):
    """Assert that we don't create duplicated DMatrix.

    The constructor of ``DMatrix`` is temporarily replaced by a counting
    wrapper, then ``model(n_estimators=1).fit`` is run twice:

    1. with ``eval_set=[(X, y)]`` -- the very same objects as the training
       data, so exactly one ``DMatrix`` must be constructed;
    2. with ``eval_set=[(X, y.copy())]`` -- a distinct label object, so two
       ``DMatrix`` objects must be constructed.

    Parameters
    ----------
    DMatrix :
        The matrix class whose ``__init__`` calls are counted.
    model :
        Estimator class; instantiated as ``model(n_estimators=1)`` and
        fitted with ``fit(X, y, eval_set=...)``.
    X, y :
        Training data and labels. ``y`` must provide ``copy()``.

    Raises
    ------
    AssertionError
        If the number of constructed matrices differs from the expectation.

    The original ``DMatrix.__init__`` is always restored, even when ``fit``
    raises or an assertion fails, so a failure here cannot leak the patched
    constructor into other tests.
    """
    old_init = DMatrix.__init__
    count = 0

    def new_init(self, **kwargs):
        # Keyword-only on purpose: mirrors how the wrapped constructor is
        # invoked by this helper's callers in the original implementation.
        nonlocal count
        count += 1
        return old_init(self, **kwargs)

    DMatrix.__init__ = new_init
    try:
        model(n_estimators=1).fit(X, y, eval_set=[(X, y)])
        assert count == 1  # only 1 DMatrix is created.

        count = 0
        y_copy = y.copy()
        model(n_estimators=1).fit(X, y, eval_set=[(X, y_copy)])
        assert count == 2  # a different Python object is considered different
    finally:
        # Undo the monkeypatch regardless of the outcome above.
        DMatrix.__init__ = old_init
|
||||||
|
|
||||||
|
|
||||||
|
def test_data_initialization():
    """Run ``run_data_initialization`` for ``xgb.DMatrix`` / ``xgb.XGBClassifier``.

    Uses the scikit-learn digits dataset as the training data.
    """
    from sklearn.datasets import load_digits
    X, y = load_digits(return_X_y=True)
    run_data_initialization(xgb.DMatrix, xgb.XGBClassifier, X, y)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user