diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index ee16467af..2cae5fc5c 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -38,7 +38,8 @@ from .core import Objective, Metric
 from .core import _deprecate_positional_args
 from .training import train as worker_train
 from .tracker import RabitTracker, get_host_ip
-from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase, _objective_decorator
+from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase
+from .sklearn import _wrap_evaluation_matrices, _objective_decorator
 from .sklearn import XGBRankerMixIn
 from .sklearn import xgboost_model_doc
 from .sklearn import _cls_predict_proba
@@ -588,7 +589,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
         weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
         missing: float = None,
-        silent: bool = False,
+        silent: bool = False,  # pylint: disable=unused-argument
         feature_names: Optional[Union[str, List[str]]] = None,
         feature_types: Optional[Union[Any, List[Any]]] = None,
         max_bin: int = 256,
@@ -1292,44 +1293,24 @@ def inplace_predict(
                            missing=missing)


-async def _evaluation_matrices(
-    client: "distributed.Client",
-    validation_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
-    sample_weight: Optional[List[_DaskCollection]],
-    sample_qid: Optional[List[_DaskCollection]],
-    missing: float
-) -> Optional[List[Tuple[DaskDMatrix, str]]]:
-    '''
-    Parameters
-    ----------
-    validation_set: list of tuples
-        Each tuple contains a validation dataset including input X and label y.
-        E.g.:
-
-        .. code-block:: python
-
-          [(X_0, y_0), (X_1, y_1), ... ]
-
-    sample_weights: list of arrays
-        The weight vector for validation data.
-
-    Returns
-    -------
-    evals: list of validation DMatrix
-    '''
-    evals: Optional[List[Tuple[DaskDMatrix, str]]] = []
-    if validation_set is not None:
-        assert isinstance(validation_set, list)
-        for i, e in enumerate(validation_set):
-            w = sample_weight[i] if sample_weight is not None else None
-            qid = sample_qid[i] if sample_qid is not None else None
-            dmat = await DaskDMatrix(client=client, data=e[0], label=e[1],
-                                     weight=w, missing=missing, qid=qid)
-            assert isinstance(evals, list)
-            evals.append((dmat, 'validation_{}'.format(i)))
-    else:
-        evals = None
-    return evals
+async def _async_wrap_evaluation_matrices(
+    client: "distributed.Client", **kwargs: Any
+) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
+    """Wrap evaluation matrices for the asynchronous dask environment."""
+    def _inner(**kwargs: Any) -> DaskDMatrix:
+        m = DaskDMatrix(client=client, **kwargs)
+        return m
+    train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs)
+    train_dmatrix = await train_dmatrix
+    if evals is None:
+        return train_dmatrix, evals
+    awaited = []
+    for e in evals:
+        if e[0] is train_dmatrix:  # already awaited
+            awaited.append(e)
+            continue
+        awaited.append((await e[0], e[1]))
+    return train_dmatrix, awaited


 class DaskScikitLearnBase(XGBModel):
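The async wrapper above works because `_wrap_evaluation_matrices` (added to sklearn.py below) treats `create_dmatrix` as an opaque factory: the shared helper never awaits anything, it only collects whatever the factory returns, and the asynchronous caller awaits those handles afterwards. A minimal, self-contained sketch of that callback pattern — all names here are hypothetical stand-ins, not xgboost APIs:

import asyncio
from typing import Any, Callable, List

def wrap_matrices(create_dmatrix: Callable, datasets: List[Any]) -> List[Any]:
    # Synchronous helper: collects the factory's return values without awaiting.
    return [create_dmatrix(data=d) for d in datasets]

async def make_matrix(data: Any) -> str:
    await asyncio.sleep(0)  # stand-in for lazy, distributed construction
    return f"DMatrix({data})"

async def main() -> None:
    handles = wrap_matrices(lambda data: make_matrix(data), [1, 2])
    print([await h for h in handles])  # ['DMatrix(1)', 'DMatrix(2)']

asyncio.run(main())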
@@ -1337,7 +1318,6 @@ class DaskScikitLearnBase(XGBModel):

     _client = None

-    @_deprecate_positional_args
     async def _predict_async(
         self, data: _DaskCollection,
         output_margin: bool = False,
@@ -1404,25 +1384,30 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
         eval_metric: Optional[Union[str, List[str], Metric]],
         sample_weight_eval_set: Optional[List[_DaskCollection]],
+        base_margin_eval_set: Optional[List[_DaskCollection]],
         early_stopping_rounds: int,
         verbose: bool,
         xgb_model: Optional[Union[Booster, XGBModel]],
         feature_weights: Optional[_DaskCollection],
         callbacks: Optional[List[TrainingCallback]],
     ) -> _DaskCollection:
-        dtrain = await DaskDMatrix(
+        params = self.get_xgb_params()
+        dtrain, evals = await _async_wrap_evaluation_matrices(
             client=self.client,
-            data=X,
-            label=y,
-            weight=sample_weight,
+            X=X,
+            y=y,
+            group=None,
+            qid=None,
+            sample_weight=sample_weight,
             base_margin=base_margin,
             feature_weights=feature_weights,
+            eval_set=eval_set,
+            sample_weight_eval_set=sample_weight_eval_set,
+            base_margin_eval_set=base_margin_eval_set,
+            eval_group=None,
+            eval_qid=None,
             missing=self.missing,
         )
-        params = self.get_xgb_params()
-        evals = await _evaluation_matrices(
-            self.client, eval_set, sample_weight_eval_set, None, self.missing
-        )

         if callable(self.objective):
             obj = _objective_decorator(self.objective)
@@ -1449,7 +1434,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         self.evals_result_ = results["history"]
         return self

-    # pylint: disable=missing-docstring
+    # pylint: disable=missing-docstring, unused-argument
     @_deprecate_positional_args
     def fit(
         self,
@@ -1464,25 +1449,13 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         verbose: bool = True,
         xgb_model: Optional[Union[Booster, XGBModel]] = None,
         sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
+        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,
         callbacks: Optional[List[TrainingCallback]] = None
     ) -> "DaskXGBRegressor":
         _assert_dask_support()
-        return self.client.sync(
-            self._fit_async,
-            X=X,
-            y=y,
-            sample_weight=sample_weight,
-            base_margin=base_margin,
-            eval_set=eval_set,
-            eval_metric=eval_metric,
-            sample_weight_eval_set=sample_weight_eval_set,
-            early_stopping_rounds=early_stopping_rounds,
-            verbose=verbose,
-            xgb_model=xgb_model,
-            feature_weights=feature_weights,
-            callbacks=callbacks,
-        )
+        args = {k: v for k, v in locals().items() if k != "self"}
+        return self.client.sync(self._fit_async, **args)


 @xgboost_model_doc(
@@ -1497,20 +1470,30 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
         eval_metric: Optional[Union[str, List[str], Metric]],
         sample_weight_eval_set: Optional[List[_DaskCollection]],
+        base_margin_eval_set: Optional[List[_DaskCollection]],
         early_stopping_rounds: int,
         verbose: bool,
         xgb_model: Optional[Union[Booster, XGBModel]],
         feature_weights: Optional[_DaskCollection],
         callbacks: Optional[List[TrainingCallback]]
     ) -> "DaskXGBClassifier":
-        dtrain = await DaskDMatrix(client=self.client,
-                                   data=X,
-                                   label=y,
-                                   weight=sample_weight,
-                                   base_margin=base_margin,
-                                   feature_weights=feature_weights,
-                                   missing=self.missing)
         params = self.get_xgb_params()
+        dtrain, evals = await _async_wrap_evaluation_matrices(
+            self.client,
+            X=X,
+            y=y,
+            group=None,
+            qid=None,
+            sample_weight=sample_weight,
+            base_margin=base_margin,
+            feature_weights=feature_weights,
+            eval_set=eval_set,
+            sample_weight_eval_set=sample_weight_eval_set,
+            base_margin_eval_set=base_margin_eval_set,
+            eval_group=None,
+            eval_qid=None,
+            missing=self.missing,
+        )

         # pylint: disable=attribute-defined-outside-init
         if isinstance(y, (da.Array)):
@@ -1525,11 +1508,6 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         else:
             params["objective"] = "binary:logistic"

-        evals = await _evaluation_matrices(self.client, eval_set,
-                                           sample_weight_eval_set,
-                                           None,
-                                           self.missing)
-
         if callable(self.objective):
             obj = _objective_decorator(self.objective)
         else:
@@ -1561,6 +1539,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         self.evals_result_ = results['history']
         return self

+    # pylint: disable=unused-argument
     @_deprecate_positional_args
     def fit(
         self,
@@ -1575,25 +1554,13 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         verbose: bool = True,
         xgb_model: Optional[Union[Booster, XGBModel]] = None,
         sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
+        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,
         callbacks: Optional[List[TrainingCallback]] = None
     ) -> "DaskXGBClassifier":
         _assert_dask_support()
-        return self.client.sync(
-            self._fit_async,
-            X=X,
-            y=y,
-            sample_weight=sample_weight,
-            base_margin=base_margin,
-            eval_set=eval_set,
-            eval_metric=eval_metric,
-            sample_weight_eval_set=sample_weight_eval_set,
-            early_stopping_rounds=early_stopping_rounds,
-            verbose=verbose,
-            xgb_model=xgb_model,
-            feature_weights=feature_weights,
-            callbacks=callbacks,
-        )
+        args = {k: v for k, v in locals().items() if k != 'self'}
+        return self.client.sync(self._fit_async, **args)

     async def _predict_proba_async(
         self,
@@ -1613,7 +1580,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             output_margin=output_margin)
         return _cls_predict_proba(self.objective, pred_probs, da.vstack)

-    # pylint: disable=missing-docstring
+    # pylint: disable=missing-function-docstring
     def predict_proba(
         self,
         X: _DaskCollection,
@@ -1632,6 +1599,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             output_margin=output_margin,
             base_margin=base_margin
         )
+    predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__

     async def _predict_async(
         self, data: _DaskCollection,
@@ -1673,11 +1641,14 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         self,
         X: _DaskCollection,
         y: _DaskCollection,
+        group: Optional[_DaskCollection],
         qid: Optional[_DaskCollection],
         sample_weight: Optional[_DaskCollection],
         base_margin: Optional[_DaskCollection],
         eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]],
         sample_weight_eval_set: Optional[List[_DaskCollection]],
+        base_margin_eval_set: Optional[List[_DaskCollection]],
+        eval_group: Optional[List[_DaskCollection]],
         eval_qid: Optional[List[_DaskCollection]],
         eval_metric: Optional[Union[str, List[str], Metric]],
         early_stopping_rounds: int,
@@ -1686,22 +1657,26 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         feature_weights: Optional[_DaskCollection],
         callbacks: Optional[List[TrainingCallback]],
     ) -> "DaskXGBRanker":
-        dtrain = await DaskDMatrix(
-            client=self.client,
-            data=X,
-            label=y,
+        msg = "Use `qid` instead of `group` on dask interface."
+        if not (group is None and eval_group is None):
+            raise ValueError(msg)
+        if qid is None:
+            raise ValueError("`qid` is required for ranking.")
+        params = self.get_xgb_params()
+        dtrain, evals = await _async_wrap_evaluation_matrices(
+            self.client,
+            X=X,
+            y=y,
+            group=None,
             qid=qid,
-            weight=sample_weight,
+            sample_weight=sample_weight,
             base_margin=base_margin,
             feature_weights=feature_weights,
-            missing=self.missing,
-        )
-        params = self.get_xgb_params()
-        evals = await _evaluation_matrices(
-            self.client,
-            eval_set,
-            sample_weight_eval_set,
-            sample_qid=eval_qid,
+            eval_set=eval_set,
+            sample_weight_eval_set=sample_weight_eval_set,
+            base_margin_eval_set=base_margin_eval_set,
+            eval_group=None,
+            eval_qid=eval_qid,
             missing=self.missing,
         )
         if eval_metric is not None:
@@ -1728,8 +1703,9 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         self.evals_result_ = results["history"]
         return self

+    # pylint: disable=unused-argument, arguments-differ
     @_deprecate_positional_args
-    def fit(  # pylint: disable=arguments-differ
+    def fit(
         self,
         X: _DaskCollection,
         y: _DaskCollection,
@@ -1739,39 +1715,20 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         sample_weight: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
         eval_set: Optional[List[Tuple[_DaskCollection, _DaskCollection]]] = None,
-        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
         eval_group: Optional[List[_DaskCollection]] = None,
         eval_qid: Optional[List[_DaskCollection]] = None,
         eval_metric: Optional[Union[str, List[str], Metric]] = None,
         early_stopping_rounds: int = None,
         verbose: bool = False,
         xgb_model: Optional[Union[XGBModel, Booster]] = None,
+        sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
+        base_margin_eval_set: Optional[List[_DaskCollection]] = None,
         feature_weights: Optional[_DaskCollection] = None,
         callbacks: Optional[List[TrainingCallback]] = None
     ) -> "DaskXGBRanker":
         _assert_dask_support()
-        msg = "Use `qid` instead of `group` on dask interface."
-        if not (group is None and eval_group is None):
-            raise ValueError(msg)
-        if qid is None:
-            raise ValueError("`qid` is required for ranking.")
-        return self.client.sync(
-            self._fit_async,
-            X=X,
-            y=y,
-            qid=qid,
-            sample_weight=sample_weight,
-            base_margin=base_margin,
-            eval_set=eval_set,
-            sample_weight_eval_set=sample_weight_eval_set,
-            eval_qid=eval_qid,
-            eval_metric=eval_metric,
-            early_stopping_rounds=early_stopping_rounds,
-            verbose=verbose,
-            xgb_model=xgb_model,
-            feature_weights=feature_weights,
-            callbacks=callbacks,
-        )
+        args = {k: v for k, v in locals().items() if k != 'self'}
+        return self.client.sync(self._fit_async, **args)

+    # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
     fit.__doc__ = XGBRanker.fit.__doc__
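A note on the `args = {k: v for k, v in locals().items() if k != "self"}` idiom the rewritten dask `fit` methods now share: `locals()` is read before any other local name is bound (calling `_assert_dask_support()` binds nothing), so it holds exactly the declared parameters, and the signature stays the single source of truth when forwarding to `_fit_async`. A standalone sketch with a hypothetical class, not a dask API:

class Example:
    def _fit_impl(self, X, y, sample_weight=None):
        return (X, y, sample_weight)

    def fit(self, X, y, sample_weight=None):
        # locals() here holds only the parameters; drop `self` and forward.
        args = {k: v for k, v in locals().items() if k != "self"}
        return self._fit_impl(**args)

assert Example().fit([[1]], [0]) == ([[1]], [0], None)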
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index ebf552e1b..398fd63ef 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -212,6 +212,105 @@ Parameters
     return adddoc


+def _wrap_evaluation_matrices(
+    missing: float,
+    X: Any,
+    y: Any,
+    group: Optional[Any],
+    qid: Optional[Any],
+    sample_weight: Optional[Any],
+    base_margin: Optional[Any],
+    feature_weights: Optional[Any],
+    eval_set: Optional[List[Tuple[Any, Any]]],
+    sample_weight_eval_set: Optional[List[Any]],
+    base_margin_eval_set: Optional[List[Any]],
+    eval_group: Optional[List[Any]],
+    eval_qid: Optional[List[Any]],
+    create_dmatrix: Callable,
+    label_transform: Callable = lambda x: x,
+) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
+    """Convert array_like evaluation matrices into DMatrix. Perform validation on the
+    way.
+
+    """
+    train_dmatrix = create_dmatrix(
+        data=X,
+        label=label_transform(y),
+        group=group,
+        qid=qid,
+        weight=sample_weight,
+        base_margin=base_margin,
+        feature_weights=feature_weights,
+        missing=missing,
+    )
+
+    def validate_or_none(meta: Optional[List], name: str) -> List:
+        if meta is None:
+            return [None] * len(eval_set)
+        if len(meta) != len(eval_set):
+            raise ValueError(
+                f"{name}'s length does not equal that of `eval_set`, "
+                + f"expecting {len(eval_set)}, got {len(meta)}"
+            )
+        return meta
+
+    if eval_set is not None:
+        sample_weight_eval_set = validate_or_none(
+            sample_weight_eval_set, "sample_weight_eval_set"
+        )
+        base_margin_eval_set = validate_or_none(
+            base_margin_eval_set, "base_margin_eval_set"
+        )
+        eval_group = validate_or_none(eval_group, "eval_group")
+        eval_qid = validate_or_none(eval_qid, "eval_qid")
+
+        evals = []
+        for i, (valid_X, valid_y) in enumerate(eval_set):
+            # Skip the duplicated entry.
+            if all(
+                (
+                    valid_X is X, valid_y is y,
+                    sample_weight_eval_set[i] is sample_weight,
+                    base_margin_eval_set[i] is base_margin,
+                    eval_group[i] is group,
+                    eval_qid[i] is qid
+                )
+            ):
+                evals.append(train_dmatrix)
+            else:
+                m = create_dmatrix(
+                    data=valid_X,
+                    label=label_transform(valid_y),
+                    weight=sample_weight_eval_set[i],
+                    group=eval_group[i],
+                    qid=eval_qid[i],
+                    base_margin=base_margin_eval_set[i],
+                    missing=missing,
+                )
+                evals.append(m)
+        nevals = len(evals)
+        eval_names = ["validation_{}".format(i) for i in range(nevals)]
+        evals = list(zip(evals, eval_names))
+    else:
+        if any(
+            [
+                meta is not None
+                for meta in [
+                    sample_weight_eval_set,
+                    base_margin_eval_set,
+                    eval_group,
+                    eval_qid,
+                ]
+            ]
+        ):
+            raise ValueError(
+                "`eval_set` is not set but one of the other evaluation meta info is "
+                "not None."
+            )
+        evals = []
+
+    return train_dmatrix, evals
+
+
 @xgboost_model_doc("""Implementation of the Scikit-Learn API for XGBoost.""",
                    ['estimators', 'model', 'objective'])
 class XGBModel(XGBModelBase):
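Worth noting about the duplicated-entry check above: reuse is decided by object identity (`is`), not value equality, so passing the exact training objects back through `eval_set` skips a second DMatrix while an equal copy does not — this is what the new `run_data_initialization` test at the bottom of this patch asserts. A small sketch of the user-visible effect, assuming the post-patch wrappers:

import numpy as np
import xgboost as xgb

X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)

clf = xgb.XGBClassifier(n_estimators=2, use_label_encoder=False)
clf.fit(X, y, eval_set=[(X, y)])         # same objects: training DMatrix is reused
clf.fit(X, y, eval_set=[(X, y.copy())])  # y.copy() is a new object: a second DMatrix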
@@ -281,69 +380,6 @@ class XGBModel(XGBModelBase):
         self.gpu_id = gpu_id
         self.validate_parameters = validate_parameters

-    def _wrap_evaluation_matrices(
-            self, X, y,
-            group,
-            qid,
-            sample_weight,
-            base_margin,
-            feature_weights,
-            eval_set,
-            sample_weight_eval_set,
-            eval_group,
-            eval_qid,
-            label_transform=lambda x: x
-    ) -> Tuple[DMatrix, Optional[List[Tuple[DMatrix, str]]]]:
-        '''Convert array_like evaluation matrices into DMatrix, group and qid are only used for
-        valiation but not set in this function
-
-        '''
-        if sample_weight_eval_set is not None:
-            assert eval_set is not None
-            assert len(sample_weight_eval_set) == len(eval_set)
-        if eval_group is not None:
-            assert eval_set is not None
-            assert len(eval_group) == len(eval_set)
-
-        y = label_transform(y)
-        train_dmatrix = DMatrix(data=X, label=y, weight=sample_weight,
-                                base_margin=base_margin,
-                                missing=self.missing, nthread=self.n_jobs)
-        train_dmatrix.set_info(feature_weights=feature_weights)
-
-        if eval_set is not None:
-            if sample_weight_eval_set is None:
-                sample_weight_eval_set = [None] * len(eval_set)
-            if eval_group is None:
-                eval_group = [None] * len(eval_set)
-            if eval_qid is None:
-                eval_qid = [None] * len(eval_set)
-
-            evals = []
-            for i, (valid_X, valid_y) in enumerate(eval_set):
-                # Skip the duplicated entry.
-                if (
-                    valid_X is X and
-                    valid_y is y and
-                    sample_weight_eval_set[i] is sample_weight and
-                    eval_group[i] is group and
-                    eval_qid[i] is qid
-                ):
-                    evals.append(train_dmatrix)
-                else:
-                    m = DMatrix(valid_X,
-                                label=label_transform(valid_y),
-                                missing=self.missing, weight=sample_weight_eval_set[i],
-                                nthread=self.n_jobs)
-                    evals.append(m)
-
-            nevals = len(evals)
-            eval_names = ["validation_{}".format(i) for i in range(nevals)]
-            evals = list(zip(evals, eval_names))
-        else:
-            evals = ()
-        return train_dmatrix, evals
-
     def _more_tags(self):
         '''Tags used for scikit-learn data validation.'''
         return {'allow_nan': True, 'no_validation': True}
@@ -580,17 +616,29 @@ class XGBModel(XGBModelBase):
         self.evals_result_ = evals_result

     @_deprecate_positional_args
-    def fit(self, X, y, *, sample_weight=None, base_margin=None,
-            eval_set=None, eval_metric=None, early_stopping_rounds=None,
-            verbose=True, xgb_model=None, sample_weight_eval_set=None,
-            feature_weights=None,
-            callbacks=None):
+    def fit(
+        self,
+        X,
+        y,
+        *,
+        sample_weight=None,
+        base_margin=None,
+        eval_set=None,
+        eval_metric=None,
+        early_stopping_rounds=None,
+        verbose=True,
+        xgb_model: Optional[Union[Booster, str, "XGBModel"]] = None,
+        sample_weight_eval_set=None,
+        base_margin_eval_set=None,
+        feature_weights=None,
+        callbacks=None
+    ):
         # pylint: disable=invalid-name,attribute-defined-outside-init
         """Fit gradient boosting model.

-        Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
-        scratch. To resume training from a previous checkpoint, explicitly pass ``xgb_model``
-        argument.
+        Note that calling ``fit()`` multiple times will cause the model object to be
+        re-fit from scratch. To resume training from a previous checkpoint, explicitly
+        pass ``xgb_model`` argument.

         Parameters
         ----------
@@ -611,12 +659,11 @@ class XGBModel(XGBModelBase):
             doc/parameter.rst.
             If a list of str, should be the list of multiple built-in evaluation metrics
             to use.
-            If callable, a custom evaluation metric. The call
-            signature is ``func(y_predicted, y_true)`` where ``y_true`` will be a
-            DMatrix object such that you may need to call the ``get_label``
-            method. It must return a str, value pair where the str is a name
-            for the evaluation and value is the value of the evaluation
-            function. The callable custom objective is always minimized.
+            If callable, a custom evaluation metric. The call signature is
+            ``func(y_predicted, y_true)`` where ``y_true`` will be a DMatrix object such
+            that you may need to call the ``get_label`` method. It must return a str,
+            value pair where the str is a name for the evaluation and value is the value
+            of the evaluation function. The callable custom objective is always minimized.
         early_stopping_rounds : int
             Activates early stopping. Validation metric needs to improve at least once in
             every **early_stopping_rounds** round(s) to continue training.
@@ -629,14 +676,17 @@ class XGBModel(XGBModelBase):
             If early stopping occurs, the model will have three additional fields:
             ``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
         verbose : bool
-            If `verbose` and an evaluation set is used, writes the evaluation
-            metric measured on the validation set to stderr.
-        xgb_model : Union[str, Booster, XGBModel]
+            If `verbose` and an evaluation set is used, writes the evaluation metric
+            measured on the validation set to stderr.
+        xgb_model : file name of stored XGBoost model or 'Booster' instance
             XGBoost model to be loaded before training (allows training continuation).
         sample_weight_eval_set : list, optional
-            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
-            instance weights on the i-th validation set.
+            A list of the form [L_1, L_2, ..., L_n], where each L_i is an array-like
+            object storing instance weights for the i-th validation set.
+        base_margin_eval_set : list, optional
+            A list of the form [M_1, M_2, ..., M_n], where each M_i is an array-like
+            object storing base margin for the i-th validation set.
         feature_weights: array_like
             Weight for each feature, defines the probability of each feature being
             selected when colsample is being used.  All values must be greater than 0,
@@ -655,12 +705,21 @@ class XGBModel(XGBModelBase):
         """
         evals_result = {}
-        train_dmatrix, evals = self._wrap_evaluation_matrices(
-            X, y, group=None, qid=None, sample_weight=sample_weight,
+        train_dmatrix, evals = _wrap_evaluation_matrices(
+            missing=self.missing,
+            X=X,
+            y=y,
+            group=None,
+            qid=None,
+            sample_weight=sample_weight,
             base_margin=base_margin,
-            feature_weights=feature_weights, eval_set=eval_set,
+            feature_weights=feature_weights,
+            eval_set=eval_set,
             sample_weight_eval_set=sample_weight_eval_set,
-            eval_group=None, eval_qid=None
+            base_margin_eval_set=base_margin_eval_set,
+            eval_group=None,
+            eval_qid=None,
+            create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
         )
         params = self.get_xgb_params()
@@ -671,13 +730,19 @@ class XGBModel(XGBModelBase):
             obj = None

         model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
-        self._Booster = train(params, train_dmatrix,
-                              self.get_num_boosting_rounds(), evals=evals,
-                              early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result,
-                              obj=obj, feval=feval,
-                              verbose_eval=verbose, xgb_model=model,
-                              callbacks=callbacks)
+        self._Booster = train(
+            params,
+            train_dmatrix,
+            self.get_num_boosting_rounds(),
+            evals=evals,
+            early_stopping_rounds=early_stopping_rounds,
+            evals_result=evals_result,
+            obj=obj,
+            feval=feval,
+            verbose_eval=verbose,
+            xgb_model=model,
+            callbacks=callbacks,
+        )

         self._set_evaluation_result(evals_result)
         return self
@@ -911,7 +976,9 @@ class XGBModel(XGBModelBase):
         return np.array(json.loads(b.get_dump(dump_format='json')[0])['bias'])


-def _cls_predict_proba(objective: Union[str, Callable], prediction: Any, vstack: Callable) -> Any:
+def _cls_predict_proba(
+    objective: Union[str, Callable], prediction: Any, vstack: Callable
+) -> Any:
     if objective == 'multi:softmax':
         raise ValueError('multi:softmax objective does not support predict_proba,'
                          ' use `multi:softprob` or `binary:logistic` instead.')
@@ -931,8 +998,8 @@
     n_estimators : int
         Number of boosting rounds.
     use_label_encoder : bool
-        (Deprecated) Use the label encoder from scikit-learn to encode the labels. For new code,
-        we recommend that you set this parameter to False.
+        (Deprecated) Use the label encoder from scikit-learn to encode the labels. For new
+        code, we recommend that you set this parameter to False.
 ''')
 class XGBClassifier(XGBModel, XGBClassifierBase):
     # pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
@@ -955,45 +1022,55 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         verbose=True,
         xgb_model=None,
         sample_weight_eval_set=None,
+        base_margin_eval_set=None,
         feature_weights=None,
         callbacks=None
     ):
-        # pylint: disable = attribute-defined-outside-init,arguments-differ,too-many-statements
-
+        # pylint: disable = attribute-defined-outside-init,too-many-statements
         can_use_label_encoder = True
         label_encoding_check_error = (
-            'The label must consist of integer labels of form 0, 1, 2, ..., [num_class - 1].')
+            "The label must consist of integer "
+            "labels of form 0, 1, 2, ..., [num_class - 1]."
+        )
         label_encoder_deprecation_msg = (
-            'The use of label encoder in XGBClassifier is deprecated and will be ' +
-            'removed in a future release. To remove this warning, do the ' +
-            'following: 1) Pass option use_label_encoder=False when constructing ' +
-            'XGBClassifier object; and 2) Encode your labels (y) as integers ' +
-            'starting with 0, i.e. 0, 1, 2, ..., [num_class - 1].')
+            "The use of label encoder in XGBClassifier is deprecated and will be "
+            "removed in a future release. To remove this warning, do the "
+            "following: 1) Pass option use_label_encoder=False when constructing "
+            "XGBClassifier object; and 2) Encode your labels (y) as integers "
+            "starting with 0, i.e. 0, 1, 2, ..., [num_class - 1]."
+        )

         evals_result = {}

         if _is_cudf_df(y) or _is_cudf_ser(y):
             import cupy as cp  # pylint: disable=E0401
+
             self.classes_ = cp.unique(y.values)
             self.n_classes_ = len(self.classes_)
             can_use_label_encoder = False
             expected_classes = cp.arange(self.n_classes_)
-            if (self.classes_.shape != expected_classes.shape or
-                    not (self.classes_ == expected_classes).all()):
+            if (
+                self.classes_.shape != expected_classes.shape
+                or not (self.classes_ == expected_classes).all()
+            ):
                 raise ValueError(label_encoding_check_error)
         elif _is_cupy_array(y):
             import cupy as cp  # pylint: disable=E0401
+
             self.classes_ = cp.unique(y)
             self.n_classes_ = len(self.classes_)
             can_use_label_encoder = False
             expected_classes = cp.arange(self.n_classes_)
-            if (self.classes_.shape != expected_classes.shape or
-                    not (self.classes_ == expected_classes).all()):
+            if (
+                self.classes_.shape != expected_classes.shape
+                or not (self.classes_ == expected_classes).all()
+            ):
                 raise ValueError(label_encoding_check_error)
         else:
             self.classes_ = np.unique(y)
             self.n_classes_ = len(self.classes_)
             if not self.use_label_encoder and (
-                    not np.array_equal(self.classes_, np.arange(self.n_classes_))):
+                not np.array_equal(self.classes_, np.arange(self.n_classes_))
+            ):
                 raise ValueError(label_encoding_check_error)

         params = self.get_xgb_params()
@@ -1008,8 +1085,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
         if self.n_classes_ > 2:
             # Switch to using a multiclass objective in the underlying
             # XGB instance
-            params['objective'] = 'multi:softprob'
-            params['num_class'] = self.n_classes_
+            params["objective"] = "multi:softprob"
+            params["num_class"] = self.n_classes_

         if self.use_label_encoder:
             if not can_use_label_encoder:
@@ -1021,34 +1098,45 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
                 self._le = XGBoostLabelEncoder().fit(y)
                 label_transform = self._le.transform
         else:
-            label_transform = (lambda x: x)
+            label_transform = lambda x: x

         model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
         if len(X.shape) != 2:
             # Simply raise an error here since there might be many
             # different ways of reshaping
-            raise ValueError(
-                'Please reshape the input data X into 2-dimensional matrix.')
+            raise ValueError("Please reshape the input data X into 2-dimensional matrix.")

-        train_dmatrix, evals = self._wrap_evaluation_matrices(
-            X, y, group=None, qid=None,
+        train_dmatrix, evals = _wrap_evaluation_matrices(
+            missing=self.missing,
+            X=X,
+            y=y,
+            group=None,
+            qid=None,
             sample_weight=sample_weight,
             base_margin=base_margin,
             feature_weights=feature_weights,
             eval_set=eval_set,
             sample_weight_eval_set=sample_weight_eval_set,
+            base_margin_eval_set=base_margin_eval_set,
             eval_group=None,
             eval_qid=None,
-            label_transform=label_transform
+            create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
+            label_transform=label_transform,
         )

-        self._Booster = train(params, train_dmatrix,
-                              self.get_num_boosting_rounds(),
-                              evals=evals,
-                              early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, obj=obj, feval=feval,
-                              verbose_eval=verbose, xgb_model=model,
-                              callbacks=callbacks)
+        self._Booster = train(
+            params,
+            train_dmatrix,
+            self.get_num_boosting_rounds(),
+            evals=evals,
+            early_stopping_rounds=early_stopping_rounds,
+            evals_result=evals_result,
+            obj=obj,
+            feval=feval,
+            verbose_eval=verbose,
+            xgb_model=model,
+            callbacks=callbacks,
+        )

         if not callable(self.objective):
             self.objective = params["objective"]
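The `label_encoding_check_error` branches above mean that with `use_label_encoder=False` the classifier accepts only labels already encoded as consecutive integers starting at zero. A short sketch of that contract, assuming a build carrying this patch (or any release where the deprecated encoder is disabled):

import numpy as np
import xgboost as xgb

X = np.random.rand(6, 3)
y_ok = np.array([0, 1, 2, 0, 1, 2])
y_bad = np.array([1, 2, 3, 1, 2, 3])  # does not start at 0: rejected

xgb.XGBClassifier(n_estimators=1, use_label_encoder=False).fit(X, y_ok)
try:
    xgb.XGBClassifier(n_estimators=1, use_label_encoder=False).fit(X, y_bad)
except ValueError as e:
    print(e)  # The label must consist of integer labels of form 0, 1, 2, ...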
@@ -1178,8 +1266,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
     n_estimators : int
         Number of trees in random forest to fit.
     use_label_encoder : bool
-        (Deprecated) Use the label encoder from scikit-learn to encode the labels. For new code,
-        we recommend that you set this parameter to False.
+        (Deprecated) Use the label encoder from scikit-learn to encode the labels. For new
+        code, we recommend that you set this parameter to False.
 ''')
 class XGBRFClassifier(XGBClassifier):
     # pylint: disable=missing-docstring
@@ -1285,11 +1373,10 @@ class XGBRFRegressor(XGBRegressor):
 class XGBRanker(XGBModel, XGBRankerMixIn):
     # pylint: disable=missing-docstring,too-many-arguments,invalid-name
     @_deprecate_positional_args
-    def __init__(self, *, objective='rank:pairwise', **kwargs):
+    def __init__(self, *, objective="rank:pairwise", **kwargs):
         super().__init__(objective=objective, **kwargs)
         if callable(self.objective):
-            raise ValueError(
-                "custom objective function not supported by XGBRanker")
+            raise ValueError("custom objective function not supported by XGBRanker")
         if "rank:" not in self.objective:
             raise ValueError("please use XGBRanker for ranking task")

@@ -1304,22 +1391,23 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
         sample_weight=None,
         base_margin=None,
         eval_set=None,
-        sample_weight_eval_set=None,
         eval_group=None,
         eval_qid=None,
         eval_metric=None,
         early_stopping_rounds=None,
         verbose=False,
-        xgb_model=None,
+        xgb_model: Optional[Union[Booster, str, XGBModel]] = None,
+        sample_weight_eval_set=None,
+        base_margin_eval_set=None,
         feature_weights=None,
         callbacks=None
     ) -> "XGBRanker":
         # pylint: disable = attribute-defined-outside-init,arguments-differ
         """Fit gradient boosting ranker

-        Note that calling ``fit()`` multiple times will cause the model object to be re-fit from
-        scratch. To resume training from a previous checkpoint, explicitly pass ``xgb_model``
-        argument.
+        Note that calling ``fit()`` multiple times will cause the model object to be
+        re-fit from scratch. To resume training from a previous checkpoint, explicitly
+        pass ``xgb_model`` argument.

         Parameters
         ----------
@@ -1343,24 +1431,12 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
             data point). This is because we only care about the relative ordering of
             data points within each group, so it doesn't make sense to assign
             weights to individual data points.
-
         base_margin : array_like
             Global bias for each instance.
         eval_set : list, optional
             A list of (X, y) tuple pairs to use as validation sets, for which
             metrics will be computed.
             Validation metrics will help us track the performance of the model.
-        sample_weight_eval_set : list, optional
-            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
-            group weights on the i-th validation set.
-
-            .. note:: Weights are per-group for ranking tasks
-
-                In ranking task, one weight is assigned to each query group (not each
-                data point). This is because we only care about the relative ordering of
-                data points within each group, so it doesn't make sense to assign
-                weights to individual data points.
-
         eval_group : list of arrays, optional
             A list in which ``eval_group[i]`` is the list containing the sizes of all
             query groups in the ``i``-th pair in **eval_set**.
@@ -1374,22 +1450,34 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
             to use. The custom evaluation metric is not yet supported for the ranker.
         early_stopping_rounds : int
             Activates early stopping. Validation metric needs to improve at least once in
-            every **early_stopping_rounds** round(s) to continue training.
-            Requires at least one item in **eval_set**.
-            The method returns the model from the last iteration (not the best one).
-            If there's more than one item in **eval_set**, the last entry will be used
-            for early stopping.
-            If there's more than one metric in **eval_metric**, the last metric
-            will be used for early stopping.
-            If early stopping occurs, the model will have three additional
-            fields: ``clf.best_score``, ``clf.best_iteration`` and
-            ``clf.best_ntree_limit``.
+            every **early_stopping_rounds** round(s) to continue training. Requires at
+            least one item in **eval_set**.
+            The method returns the model from the last iteration (not the best one). If
+            there's more than one item in **eval_set**, the last entry will be used for
+            early stopping.
+            If there's more than one metric in **eval_metric**, the last metric will be
+            used for early stopping.
+            If early stopping occurs, the model will have three additional fields:
+            ``clf.best_score``, ``clf.best_iteration`` and ``clf.best_ntree_limit``.
         verbose : bool
-            If `verbose` and an evaluation set is used, writes the evaluation
-            metric measured on the validation set to stderr.
-        xgb_model : Union[str, Booster, XGBModel]
-            file name of stored XGBoost model or 'Booster' instance XGBoost
-            model to be loaded before training (allows training continuation).
+            If `verbose` and an evaluation set is used, writes the evaluation metric
+            measured on the validation set to stderr.
+        xgb_model :
+            file name of stored XGBoost model or 'Booster' instance XGBoost model to be
+            loaded before training (allows training continuation).
+        sample_weight_eval_set : list, optional
+            A list of the form [L_1, L_2, ..., L_n], where each L_i is a list of
+            group weights on the i-th validation set.
+
+            .. note:: Weights are per-group for ranking tasks
+
+                In ranking task, one weight is assigned to each query group (not each
+                data point). This is because we only care about the relative ordering of
+                data points within each group, so it doesn't make sense to assign
+                weights to individual data points.
+        base_margin_eval_set : list, optional
+            A list of the form [M_1, M_2, ..., M_n], where each M_i is an array-like
+            object storing base margin for the i-th validation set.
         feature_weights: array_like
             Weight for each feature, defines the probability of each feature being
             selected when colsample is being used.  All values must be greater than 0,
@@ -1406,6 +1494,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
                 save_best=True)]

         """
+        # check if group information is provided
         if group is None and qid is None:
             raise ValueError("group or qid is required for ranking task")

@@ -1413,24 +1502,10 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
         if eval_group is None and eval_qid is None:
             raise ValueError(
                 "eval_group or eval_qid is required if eval_set is not None")
-        if (
-            (eval_group is not None and len(eval_group) != len(eval_set)) and
-            (eval_qid is not None and len(eval_qid) != len(eval_set))
-        ):
-            raise ValueError(
-                "length of eval_group or eval_qid should match that of eval_set"
-            )
-        for i in range(len(eval_set)):
-            if (
-                (eval_group is not None and eval_group[i] is not None) and
-                (eval_qid is not None and eval_qid[i] is not None)
-            ):
-                raise ValueError(
-                    "Only one of the qid or group should be provided."
-                )
-
-        train_dmatrix, evals = self._wrap_evaluation_matrices(
-            X, y,
+        train_dmatrix, evals = _wrap_evaluation_matrices(
+            missing=self.missing,
+            X=X,
+            y=y,
             group=group,
             qid=qid,
             sample_weight=sample_weight,
@@ -1438,54 +1513,32 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
             feature_weights=feature_weights,
             eval_set=eval_set,
             sample_weight_eval_set=sample_weight_eval_set,
+            base_margin_eval_set=base_margin_eval_set,
             eval_group=eval_group,
-            eval_qid=eval_qid
+            eval_qid=eval_qid,
+            create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
         )
-        if qid is not None:
-            train_dmatrix.set_info(qid=qid)
-        elif group is not None:
-            train_dmatrix.set_info(group=group)
-        else:
-            raise ValueError("Either qid or group should be provided for ranking task.")
-
-        if evals is not None:
-            for i, e in enumerate(evals):
-                if eval_qid is not None and eval_qid[i] is not None:
-                    assert eval_group is None or eval_group[i] is None, (
-                        'Only one of the eval_qid or eval_group for each evaluation '
-                        'dataset should be provided.'
-                    )
-                    e[0].set_info(qid=qid)
-                elif eval_group is not None and eval_group[i] is not None:
-                    e[0].set_info(group=eval_group[i])
-                else:
-                    raise ValueError(
-                        'Either eval_qid or eval_group for each evaluation dataset should'
-                        ' be provided for ranking task.'
-                    )
-
         evals_result = {}
         params = self.get_xgb_params()

-        feval = eval_metric if callable(eval_metric) else None
-        if eval_metric is not None:
-            if callable(eval_metric):
-                raise ValueError(
-                    'Custom evaluation metric is not yet supported for XGBRanker.')
-            params.update({'eval_metric': eval_metric})
-        if hasattr(xgb_model, '_Booster'):
-            # Handle the case when xgb_model is a sklearn model object
-            xgb_model = xgb_model._Booster  # pylint: disable=protected-access
+        model, feval, params = self._configure_fit(xgb_model, eval_metric, params)
+        if callable(feval):
+            raise ValueError(
+                'Custom evaluation metric is not yet supported for XGBRanker.'
+            )

-        self._Booster = train(params, train_dmatrix,
-                              self.get_num_boosting_rounds(),
-                              early_stopping_rounds=early_stopping_rounds,
-                              evals=evals,
-                              evals_result=evals_result, feval=feval,
-                              verbose_eval=verbose, xgb_model=xgb_model,
-                              callbacks=callbacks)
+        self._Booster = train(
+            params, train_dmatrix,
+            self.n_estimators,
+            early_stopping_rounds=early_stopping_rounds,
+            evals=evals,
+            evals_result=evals_result, feval=feval,
+            verbose_eval=verbose, xgb_model=model,
+            callbacks=callbacks
+        )

         self.objective = params["objective"]
+
         self._set_evaluation_result(evals_result)
         return self
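With the sklearn.py changes complete, the new keyword is symmetric across the single-node and dask estimators: a margin supplied for training should also be supplied per validation set, otherwise evaluation runs without the offset — the behaviour the dask test below asserts via logloss. A hedged usage sketch on the single-node API (the margin value here is purely illustrative):

import numpy as np
import xgboost as xgb

X = np.random.rand(128, 4)
y = np.random.randint(0, 2, size=128)
margin = np.full(y.shape, 0.5)  # illustrative global offset

clf = xgb.XGBClassifier(n_estimators=4, use_label_encoder=False)
clf.fit(
    X, y,
    base_margin=margin,
    eval_set=[(X, y)],
    base_margin_eval_set=[margin],  # keep eval margins in step with training
)
print(clf.evals_result()["validation_0"]["logloss"])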
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index da8bd6298..51ef9c1e3 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -314,6 +314,14 @@ class TestDistributedGPU:
         for i in range(len(ddqdm_names)):
             assert ddqdm_names[i] == dqdm_names[i]

+        sig = OrderedDict(signature(xgb.XGBRanker.fit).parameters)
+        ranker_names = list(sig.keys())
+        sig = OrderedDict(signature(xgb.dask.DaskXGBRanker.fit).parameters)
+        dranker_names = list(sig.keys())
+
+        for rn, drn in zip(ranker_names, dranker_names):
+            assert rn == drn
+
     def run_quantile(self, name: str, local_cuda_cluster: LocalCUDACluster) -> None:
         if sys.platform.startswith("win"):
             pytest.skip("Skipping dask tests on Windows")
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index dedf6bfe7..4861e19da 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -17,7 +17,7 @@ import subprocess
 import hypothesis
 from hypothesis import given, settings, note, HealthCheck
 from test_updaters import hist_parameter_strategy, exact_parameter_strategy
-from test_with_sklearn import run_feature_weights
+from test_with_sklearn import run_feature_weights, run_data_initialization

 if sys.platform.startswith("win"):
     pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -176,6 +176,22 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None:

     assert np.all(predictions_1.compute() == predictions_2.compute())

+    margined = xgb.dask.DaskXGBClassifier(n_estimators=4)
+    margined.fit(
+        X=X, y=y, base_margin=margin, eval_set=[(X, y)], base_margin_eval_set=[margin]
+    )
+
+    unmargined = xgb.dask.DaskXGBClassifier(n_estimators=4)
+    unmargined.fit(X=X, y=y, eval_set=[(X, y)], base_margin=margin)
+
+    margined_res = margined.evals_result()['validation_0']['logloss']
+    unmargined_res = unmargined.evals_result()['validation_0']['logloss']
+
+    assert len(margined_res) == len(unmargined_res)
+    for i in range(len(margined_res)):
+        # margined is the correct one, so it should have the smaller error.
+        assert margined_res[i] < unmargined_res[i]
+

 def test_dask_missing_value_reg(client: "Client") -> None:
     X_0 = np.ones((20 // 2, kCols))
@@ -955,7 +971,7 @@ class TestWithDask:
                               results_native['validation_0']['rmse'])
         tm.non_increasing(results_native['validation_0']['rmse'])

-    def test_data_initialization(self) -> None:
+    def test_no_duplicated_partition(self) -> None:
         '''Assert each worker has the correct amount of data, and DMatrix initialization
         doesn't generate unnecessary copies of data.
@@ -995,6 +1011,13 @@ class TestWithDask:
         # Subtract the on disk resource from each worker
         assert cnt - n_workers == n_partitions

+    def test_data_initialization(self, client: "Client") -> None:
+        """Assert that we don't create duplicated DMatrix."""
+        from sklearn.datasets import load_digits
+        X, y = load_digits(return_X_y=True)
+        X, y = dd.from_array(X, chunksize=32), dd.from_array(y, chunksize=32)
+        run_data_initialization(xgb.dask.DaskDMatrix, xgb.dask.DaskXGBClassifier, X, y)
+
     def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None:
         X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
         Xy = xgb.dask.DaskDMatrix(client, X, y)
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 110bde611..6c6d13697 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -717,13 +717,13 @@ def test_validation_weights_xgbmodel():
     assert all((logloss_with_weights[i] != logloss_without_weights[i]
                 for i in [0, 1]))

-    with pytest.raises(AssertionError):
+    with pytest.raises(ValueError):
         # length of eval set and sample weight doesn't match.
         clf.fit(X_train, y_train, sample_weight=weights_train,
                 eval_set=[(X_train, y_train), (X_test, y_test)],
                 sample_weight_eval_set=[weights_train])

-    with pytest.raises(AssertionError):
+    with pytest.raises(ValueError):
         cls = xgb.XGBClassifier()
         cls.fit(X_train, y_train, sample_weight=weights_train,
                 eval_set=[(X_train, y_train), (X_test, y_test)],
@@ -1118,19 +1118,9 @@ def run_boost_from_prediction(tree_method):
     assert np.all(predictions_1 == predictions_2)


-@pytest.mark.skipif(**tm.no_sklearn())
-def test_boost_from_prediction_hist():
-    run_boost_from_prediction('hist')
-
-
-@pytest.mark.skipif(**tm.no_sklearn())
-def test_boost_from_prediction_approx():
-    run_boost_from_prediction('approx')
-
-
-@pytest.mark.skipif(**tm.no_sklearn())
-def test_boost_from_prediction_exact():
-    run_boost_from_prediction('exact')
+@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
+def test_boost_from_prediction(tree_method):
+    run_boost_from_prediction(tree_method)


 def test_estimator_type():
@@ -1154,3 +1144,32 @@ def test_estimator_type():

     cls = xgb.XGBClassifier()
     cls.load_model(path)  # no error
+
+
+def run_data_initialization(DMatrix, model, X, y):
+    """Assert that we don't create duplicated DMatrix."""
+
+    old_init = DMatrix.__init__
+    count = [0]
+
+    def new_init(self, **kwargs):
+        count[0] += 1
+        return old_init(self, **kwargs)
+
+    DMatrix.__init__ = new_init
+    model(n_estimators=1).fit(X, y, eval_set=[(X, y)])
+
+    assert count[0] == 1  # only 1 DMatrix is created.
+    count[0] = 0
+
+    y_copy = y.copy()
+    model(n_estimators=1).fit(X, y, eval_set=[(X, y_copy)])
+    assert count[0] == 2  # a different Python object is considered different
+
+    DMatrix.__init__ = old_init
+
+
+def test_data_initialization():
+    from sklearn.datasets import load_digits
+    X, y = load_digits(return_X_y=True)
+    run_data_initialization(xgb.DMatrix, xgb.XGBClassifier, X, y)