From 8ea705e4d55bb08d3f4d9dcfdab39167169447e8 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 21 Feb 2024 00:43:14 +0800 Subject: [PATCH] Support sample weight in sklearn custom objective. (#10050) --- python-package/xgboost/core.py | 9 +- python-package/xgboost/sklearn.py | 210 ++++++++++++------ python-package/xgboost/testing/__init__.py | 7 +- tests/ci_build/lint_python.py | 1 + tests/python/test_with_sklearn.py | 6 + .../test_with_dask/test_with_dask.py | 15 +- 6 files changed, 179 insertions(+), 69 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index f19078224..36e4bdcf0 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -804,10 +804,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m Otherwise, one can pass a list-like input with the same length as number of columns in `data`, with the following possible values: - - "c", which represents categorical columns. - - "q", which represents numeric columns. - - "int", which represents integer columns. - - "i", which represents boolean columns. + + - "c", which represents categorical columns. + - "q", which represents numeric columns. + - "int", which represents integer columns. + - "i", which represents boolean columns. Note that, while categorical types are treated differently from the rest for model fitting purposes, the other types do not influence diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 5d651948c..c4713a9e4 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -5,12 +5,14 @@ import json import os import warnings from concurrent.futures import ThreadPoolExecutor +from inspect import signature from typing import ( Any, Callable, Dict, List, Optional, + Protocol, Sequence, Tuple, Type, @@ -67,14 +69,20 @@ def _can_use_qdm(tree_method: Optional[str]) -> bool: return tree_method in ("hist", "gpu_hist", None, "auto") -SklObjective = Optional[ - Union[str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]] -] +class _SklObjWProto(Protocol): # pylint: disable=too-few-public-methods + def __call__( + self, + y_true: ArrayLike, + y_pred: ArrayLike, + sample_weight: Optional[ArrayLike], + ) -> Tuple[ArrayLike, ArrayLike]: ... -def _objective_decorator( - func: Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]] -) -> Objective: +_SklObjProto = Callable[[ArrayLike, ArrayLike], Tuple[np.ndarray, np.ndarray]] +SklObjective = Optional[Union[str, _SklObjWProto, _SklObjProto]] + + +def _objective_decorator(func: Union[_SklObjWProto, _SklObjProto]) -> Objective: """Decorate an objective function Converts an objective function using the typical sklearn metrics @@ -89,6 +97,8 @@ def _objective_decorator( The target values y_pred: array_like of shape [n_samples] The predicted values + sample_weight : + Optional sample weight, None or a ndarray. Returns ------- @@ -103,10 +113,25 @@ def _objective_decorator( ``dmatrix.get_label()`` """ + parameters = signature(func).parameters + supports_sw = "sample_weight" in parameters + def inner(preds: np.ndarray, dmatrix: DMatrix) -> Tuple[np.ndarray, np.ndarray]: - """internal function""" + """Internal function.""" + sample_weight = dmatrix.get_weight() labels = dmatrix.get_label() - return func(labels, preds) + + if sample_weight.size > 0 and not supports_sw: + raise ValueError( + "Custom objective doesn't have the `sample_weight` parameter while" + " sample_weight is used." + ) + if sample_weight.size > 0: + fnw = cast(_SklObjWProto, func) + return fnw(labels, preds, sample_weight=sample_weight) + + fn = cast(_SklObjProto, func) + return fn(labels, preds) return inner @@ -172,75 +197,121 @@ def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric: return inner -__estimator_doc = """ - n_estimators : Optional[int] +__estimator_doc = f""" + n_estimators : {Optional[int]} Number of gradient boosted trees. Equivalent to number of boosting rounds. """ __model_doc = f""" - max_depth : Optional[int] + max_depth : {Optional[int]} + Maximum tree depth for base learners. - max_leaves : + + max_leaves : {Optional[int]} + Maximum number of leaves; 0 indicates no limit. - max_bin : + + max_bin : {Optional[int]} + If using histogram-based algorithm, maximum number of bins per feature - grow_policy : - Tree growing policy. 0: favor splitting at nodes closest to the node, i.e. grow - depth-wise. 1: favor splitting at nodes with highest loss change. - learning_rate : Optional[float] + + grow_policy : {Optional[str]} + + Tree growing policy. + + - depthwise: Favors splitting at nodes closest to the node, + - lossguide: Favors splitting at nodes with highest loss change. + + learning_rate : {Optional[float]} + Boosting learning rate (xgb's "eta") - verbosity : Optional[int] + + verbosity : {Optional[int]} + The degree of verbosity. Valid values are 0 (silent) - 3 (debug). objective : {SklObjective} Specify the learning task and the corresponding learning objective or a custom - objective function to be used. For custom objective, see - :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more - information. + objective function to be used. + + For custom objective, see :doc:`/tutorials/custom_metric_obj` and + :ref:`custom-obj-metric` for more information, along with the end note for + function signatures. + + booster: {Optional[str]} + + Specify which booster to use: ``gbtree``, ``gblinear`` or ``dart``. + + tree_method : {Optional[str]} - booster: Optional[str] - Specify which booster to use: `gbtree`, `gblinear` or `dart`. - tree_method: Optional[str] Specify which tree method to use. Default to auto. If this parameter is set to default, XGBoost will choose the most conservative option available. It's recommended to study this option from the parameters document :doc:`tree method ` - n_jobs : Optional[int] + + n_jobs : {Optional[int]} + Number of parallel threads used to run xgboost. When used with other Scikit-Learn algorithms like grid search, you may choose which algorithm to parallelize and balance the threads. Creating thread contention will significantly slow down both algorithms. - gamma : Optional[float] - (min_split_loss) Minimum loss reduction required to make a further partition on a - leaf node of the tree. - min_child_weight : Optional[float] + + gamma : {Optional[float]} + + (min_split_loss) Minimum loss reduction required to make a further partition on + a leaf node of the tree. + + min_child_weight : {Optional[float]} + Minimum sum of instance weight(hessian) needed in a child. - max_delta_step : Optional[float] + + max_delta_step : {Optional[float]} + Maximum delta step we allow each tree's weight estimation to be. - subsample : Optional[float] + + subsample : {Optional[float]} + Subsample ratio of the training instance. - sampling_method : + + sampling_method : {Optional[str]} + Sampling method. Used only by the GPU version of ``hist`` tree method. - - ``uniform``: select random training instances uniformly. - - ``gradient_based`` select random training instances with higher probability + + - ``uniform``: Select random training instances uniformly. + - ``gradient_based``: Select random training instances with higher probability when the gradient and hessian are larger. (cf. CatBoost) - colsample_bytree : Optional[float] + + colsample_bytree : {Optional[float]} + Subsample ratio of columns when constructing each tree. - colsample_bylevel : Optional[float] + + colsample_bylevel : {Optional[float]} + Subsample ratio of columns for each level. - colsample_bynode : Optional[float] + + colsample_bynode : {Optional[float]} + Subsample ratio of columns for each split. - reg_alpha : Optional[float] + + reg_alpha : {Optional[float]} + L1 regularization term on weights (xgb's alpha). - reg_lambda : Optional[float] + + reg_lambda : {Optional[float]} + L2 regularization term on weights (xgb's lambda). - scale_pos_weight : Optional[float] + + scale_pos_weight : {Optional[float]} Balancing of positive and negative weights. - base_score : Optional[float] + + base_score : {Optional[float]} + The initial prediction score of all instances, global bias. - random_state : Optional[Union[numpy.random.RandomState, numpy.random.Generator, int]] + + random_state : {Optional[Union[np.random.RandomState, np.random.Generator, int]]} + Random number seed. .. note:: @@ -248,34 +319,44 @@ __model_doc = f""" Using gblinear booster with shotgun updater is nondeterministic as it uses Hogwild algorithm. - missing : float, default np.nan - Value in the data which needs to be present as a missing value. - num_parallel_tree: Optional[int] + missing : float + + Value in the data which needs to be present as a missing value. Default to + :py:data:`numpy.nan`. + + num_parallel_tree: {Optional[int]} + Used for boosting random forest. - monotone_constraints : Optional[Union[Dict[str, int], str]] + + monotone_constraints : {Optional[Union[Dict[str, int], str]]} + Constraint of variable monotonicity. See :doc:`tutorial ` for more information. - interaction_constraints : Optional[Union[str, List[Tuple[str]]]] + + interaction_constraints : {Optional[Union[str, List[Tuple[str]]]]} + Constraints for interaction representing permitted interactions. The constraints must be specified in the form of a nested list, e.g. ``[[0, 1], [2, 3, 4]]``, where each inner list is a group of indices of features that are allowed to interact with each other. See :doc:`tutorial ` for more information - importance_type: Optional[str] + + importance_type: {Optional[str]} + The feature importance type for the feature_importances\\_ property: * For tree model, it's either "gain", "weight", "cover", "total_gain" or "total_cover". - * For linear model, only "weight" is defined and it's the normalized coefficients - without bias. + * For linear model, only "weight" is defined and it's the normalized + coefficients without bias. - device : Optional[str] + device : {Optional[str]} .. versionadded:: 2.0.0 Device ordinal, available options are `cpu`, `cuda`, and `gpu`. - validate_parameters : Optional[bool] + validate_parameters : {Optional[bool]} Give warnings for unknown parameter. @@ -283,14 +364,14 @@ __model_doc = f""" See the same parameter of :py:class:`DMatrix` for details. - feature_types : Optional[FeatureTypes] + feature_types : {Optional[FeatureTypes]} .. versionadded:: 1.7.0 Used for specifying feature types without constructing a dataframe. See :py:class:`DMatrix` for details. - max_cat_to_onehot : Optional[int] + max_cat_to_onehot : {Optional[int]} .. versionadded:: 1.6.0 @@ -303,7 +384,7 @@ __model_doc = f""" categorical feature support. See :doc:`Categorical Data ` and :ref:`cat-param` for details. - max_cat_threshold : Optional[int] + max_cat_threshold : {Optional[int]} .. versionadded:: 1.7.0 @@ -314,7 +395,7 @@ __model_doc = f""" needs to be set to have categorical feature support. See :doc:`Categorical Data ` and :ref:`cat-param` for details. - multi_strategy : Optional[str] + multi_strategy : {Optional[str]} .. versionadded:: 2.0.0 @@ -327,7 +408,7 @@ __model_doc = f""" - ``one_output_per_tree``: One model for each target. - ``multi_output_tree``: Use multi-target trees. - eval_metric : Optional[Union[str, List[str], Callable]] + eval_metric : {Optional[Union[str, List[str], Callable]]} .. versionadded:: 1.6.0 @@ -360,7 +441,7 @@ __model_doc = f""" ) reg.fit(X, y, eval_set=[(X, y)]) - early_stopping_rounds : Optional[int] + early_stopping_rounds : {Optional[int]} .. versionadded:: 1.6.0 @@ -383,7 +464,8 @@ __model_doc = f""" early stopping. If there's more than one metric in **eval_metric**, the last metric will be used for early stopping. - callbacks : Optional[List[TrainingCallback]] + callbacks : {Optional[List[TrainingCallback]]} + List of callback functions that are applied at end of each iteration. It is possible to use predefined callbacks by using :ref:`Callback API `. @@ -402,7 +484,8 @@ __model_doc = f""" reg = xgboost.XGBRegressor(**params, callbacks=callbacks) reg.fit(X, y) - kwargs : dict, optional + kwargs : {Optional[Any]} + Keyword arguments for XGBoost Booster object. Full documentation of parameters can be found :doc:`here `. Attempting to set a parameter via the constructor args and \\*\\*kwargs @@ -419,13 +502,16 @@ __custom_obj_note = """ .. note:: Custom objective function A custom objective function can be provided for the ``objective`` - parameter. In this case, it should have the signature - ``objective(y_true, y_pred) -> grad, hess``: + parameter. In this case, it should have the signature ``objective(y_true, + y_pred) -> [grad, hess]`` or ``objective(y_true, y_pred, *, sample_weight) + -> [grad, hess]``: y_true: array_like of shape [n_samples] The target values y_pred: array_like of shape [n_samples] The predicted values + sample_weight : + Optional sample weights. grad: array_like of shape [n_samples] The value of the gradient for each sample point. diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 389066f0e..f7d9510fa 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -815,10 +815,15 @@ def softprob_obj( return objective -def ls_obj(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: +def ls_obj( + y_true: np.ndarray, y_pred: np.ndarray, sample_weight: Optional[np.ndarray] = None +) -> Tuple[np.ndarray, np.ndarray]: """Least squared error.""" grad = y_pred - y_true hess = np.ones(len(y_true)) + if sample_weight is not None: + grad *= sample_weight + hess *= sample_weight return grad, hess diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index 91b748b4c..741ef7558 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -100,6 +100,7 @@ class LintersPaths: # demo "demo/json-model/json_parser.py", "demo/guide-python/external_memory.py", + "demo/guide-python/sklearn_examples.py", "demo/guide-python/continuation.py", "demo/guide-python/callbacks.py", "demo/guide-python/cat_in_the_dat.py", diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index ede70bb8b..507470724 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -517,6 +517,12 @@ def test_regression_with_custom_objective(): labels = y[test_index] assert mean_squared_error(preds, labels) < 25 + w = rng.uniform(low=0.0, high=1.0, size=X.shape[0]) + reg = xgb.XGBRegressor(objective=tm.ls_obj, n_estimators=25) + reg.fit(X, y, sample_weight=w) + y_pred = reg.predict(X) + assert mean_squared_error(y_true=y, y_pred=y_pred, sample_weight=w) < 25 + # Test that the custom objective function is actually used class XGBCustomObjectiveException(Exception): pass diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py index fdf0d64c4..ffea1d058 100644 --- a/tests/test_distributed/test_with_dask/test_with_dask.py +++ b/tests/test_distributed/test_with_dask/test_with_dask.py @@ -1750,9 +1750,20 @@ class TestWithDask: ) tm.non_increasing(results_native["validation_0"]["rmse"]) + reg = xgb.dask.DaskXGBRegressor( + n_estimators=rounds, objective=tm.ls_obj, tree_method="hist" + ) + rng = da.random.RandomState(1994) + w = rng.uniform(low=0.0, high=1.0, size=y.shape[0]) + reg.fit( + X, y, sample_weight=w, eval_set=[(X, y)], sample_weight_eval_set=[w] + ) + results_custom = reg.evals_result() + tm.non_increasing(results_custom["validation_0"]["rmse"]) + def test_no_duplicated_partition(self) -> None: - """Assert each worker has the correct amount of data, and DMatrix initialization doesn't - generate unnecessary copies of data. + """Assert each worker has the correct amount of data, and DMatrix initialization + doesn't generate unnecessary copies of data. """ with LocalCluster(n_workers=2, dashboard_address=":0") as cluster: