Support sample weight in sklearn custom objective. (#10050)

Jiaming Yuan 2024-02-21 00:43:14 +08:00 committed by GitHub
parent 69a17d5114
commit 8ea705e4d5
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
6 changed files with 179 additions and 69 deletions


@@ -804,10 +804,11 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes,too-many-public-m
         Otherwise, one can pass a list-like input with the same length as number
         of columns in `data`, with the following possible values:
+
         - "c", which represents categorical columns.
         - "q", which represents numeric columns.
         - "int", which represents integer columns.
         - "i", which represents boolean columns.
 
         Note that, while categorical types are treated differently from
         the rest for model fitting purposes, the other types do not influence
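
For orientation, `feature_types` assigns one tag to each column of the input. A minimal sketch (not part of the commit; the data and column types are illustrative) of passing such a list when constructing a DMatrix:

    import numpy as np
    import xgboost as xgb

    # One tag per column: categorical, numeric, integer, boolean.
    X = np.array([[0.0, 1.5, 3.0, 1.0],
                  [1.0, 2.5, 5.0, 0.0]])
    y = np.array([0.0, 1.0])
    Xy = xgb.DMatrix(
        X, label=y, feature_types=["c", "q", "int", "i"], enable_categorical=True
    )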


@@ -5,12 +5,14 @@ import json
 import os
 import warnings
 from concurrent.futures import ThreadPoolExecutor
+from inspect import signature
 from typing import (
     Any,
     Callable,
     Dict,
     List,
     Optional,
+    Protocol,
     Sequence,
     Tuple,
     Type,
@@ -67,14 +69,20 @@ def _can_use_qdm(tree_method: Optional[str]) -> bool:
     return tree_method in ("hist", "gpu_hist", None, "auto")
 
 
-SklObjective = Optional[
-    Union[str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]
-]
+class _SklObjWProto(Protocol):  # pylint: disable=too-few-public-methods
+    def __call__(
+        self,
+        y_true: ArrayLike,
+        y_pred: ArrayLike,
+        sample_weight: Optional[ArrayLike],
+    ) -> Tuple[ArrayLike, ArrayLike]: ...
 
 
-def _objective_decorator(
-    func: Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
-) -> Objective:
+_SklObjProto = Callable[[ArrayLike, ArrayLike], Tuple[np.ndarray, np.ndarray]]
+SklObjective = Optional[Union[str, _SklObjWProto, _SklObjProto]]
+
+
+def _objective_decorator(func: Union[_SklObjWProto, _SklObjProto]) -> Objective:
     """Decorate an objective function
 
     Converts an objective function using the typical sklearn metrics
@@ -89,6 +97,8 @@ def _objective_decorator(
         The target values
     y_pred: array_like of shape [n_samples]
         The predicted values
+    sample_weight :
+        Optional sample weight, None or a ndarray.
 
     Returns
     -------
@@ -103,10 +113,25 @@ def _objective_decorator(
         ``dmatrix.get_label()``
     """
+    parameters = signature(func).parameters
+    supports_sw = "sample_weight" in parameters
 
     def inner(preds: np.ndarray, dmatrix: DMatrix) -> Tuple[np.ndarray, np.ndarray]:
-        """internal function"""
+        """Internal function."""
+        sample_weight = dmatrix.get_weight()
         labels = dmatrix.get_label()
+
+        if sample_weight.size > 0 and not supports_sw:
+            raise ValueError(
+                "Custom objective doesn't have the `sample_weight` parameter while"
+                " sample_weight is used."
+            )
+        if sample_weight.size > 0:
+            fnw = cast(_SklObjWProto, func)
+            return fnw(labels, preds, sample_weight=sample_weight)
+
+        fn = cast(_SklObjProto, func)
+        return fn(labels, preds)
 
     return inner
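
Taken together, the decorator now inspects the user function's signature and forwards DMatrix weights only when a `sample_weight` parameter exists; calling `fit` with weights against an objective that lacks the parameter raises ValueError. A minimal sketch of the two accepted forms (squared error used for illustration; not part of the commit):

    from typing import Optional, Tuple
    import numpy as np

    # Unweighted form: objective(y_true, y_pred) -> (grad, hess)
    def sq_err(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        return y_pred - y_true, np.ones_like(y_true)

    # Weighted form: the extra `sample_weight` parameter opts in to weights.
    def weighted_sq_err(
        y_true: np.ndarray,
        y_pred: np.ndarray,
        sample_weight: Optional[np.ndarray] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        grad = y_pred - y_true
        hess = np.ones_like(y_true)
        if sample_weight is not None:
            grad = grad * sample_weight
            hess = hess * sample_weight
        return grad, hess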
@@ -172,75 +197,121 @@ def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric:
     return inner
 
 
-__estimator_doc = """
-    n_estimators : Optional[int]
+__estimator_doc = f"""
+    n_estimators : {Optional[int]}
         Number of gradient boosted trees. Equivalent to number of boosting
         rounds.
 """
 
 __model_doc = f"""
-    max_depth : Optional[int]
+    max_depth : {Optional[int]}
         Maximum tree depth for base learners.
-    max_leaves :
+
+    max_leaves : {Optional[int]}
         Maximum number of leaves; 0 indicates no limit.
-    max_bin :
+
+    max_bin : {Optional[int]}
         If using histogram-based algorithm, maximum number of bins per feature
-    grow_policy :
-        Tree growing policy. 0: favor splitting at nodes closest to the node, i.e. grow
-        depth-wise. 1: favor splitting at nodes with highest loss change.
-    learning_rate : Optional[float]
+
+    grow_policy : {Optional[str]}
+        Tree growing policy.
+
+        - depthwise: Favors splitting at nodes closest to the node,
+        - lossguide: Favors splitting at nodes with highest loss change.
+
+    learning_rate : {Optional[float]}
         Boosting learning rate (xgb's "eta")
-    verbosity : Optional[int]
+
+    verbosity : {Optional[int]}
         The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
+
     objective : {SklObjective}
         Specify the learning task and the corresponding learning objective or a custom
-        objective function to be used. For custom objective, see
-        :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more
-        information.
-    booster: Optional[str]
-        Specify which booster to use: `gbtree`, `gblinear` or `dart`.
-    tree_method: Optional[str]
+        objective function to be used.
+
+        For custom objective, see :doc:`/tutorials/custom_metric_obj` and
+        :ref:`custom-obj-metric` for more information, along with the end note for
+        function signatures.
+
+    booster: {Optional[str]}
+        Specify which booster to use: ``gbtree``, ``gblinear`` or ``dart``.
+
+    tree_method : {Optional[str]}
         Specify which tree method to use. Default to auto. If this parameter is set to
         default, XGBoost will choose the most conservative option available. It's
         recommended to study this option from the parameters document :doc:`tree method
         </treemethod>`
-    n_jobs : Optional[int]
+
+    n_jobs : {Optional[int]}
         Number of parallel threads used to run xgboost. When used with other
         Scikit-Learn algorithms like grid search, you may choose which algorithm to
         parallelize and balance the threads. Creating thread contention will
         significantly slow down both algorithms.
-    gamma : Optional[float]
-        (min_split_loss) Minimum loss reduction required to make a further partition on a
-        leaf node of the tree.
-    min_child_weight : Optional[float]
+
+    gamma : {Optional[float]}
+        (min_split_loss) Minimum loss reduction required to make a further partition on
+        a leaf node of the tree.
+
+    min_child_weight : {Optional[float]}
         Minimum sum of instance weight(hessian) needed in a child.
-    max_delta_step : Optional[float]
+
+    max_delta_step : {Optional[float]}
         Maximum delta step we allow each tree's weight estimation to be.
-    subsample : Optional[float]
+
+    subsample : {Optional[float]}
         Subsample ratio of the training instance.
-    sampling_method :
+
+    sampling_method : {Optional[str]}
         Sampling method. Used only by the GPU version of ``hist`` tree method.
-          - ``uniform``: select random training instances uniformly.
-          - ``gradient_based`` select random training instances with higher probability
+
+        - ``uniform``: Select random training instances uniformly.
+        - ``gradient_based``: Select random training instances with higher probability
           when the gradient and hessian are larger. (cf. CatBoost)
-    colsample_bytree : Optional[float]
+
+    colsample_bytree : {Optional[float]}
         Subsample ratio of columns when constructing each tree.
-    colsample_bylevel : Optional[float]
+
+    colsample_bylevel : {Optional[float]}
         Subsample ratio of columns for each level.
-    colsample_bynode : Optional[float]
+
+    colsample_bynode : {Optional[float]}
         Subsample ratio of columns for each split.
-    reg_alpha : Optional[float]
+
+    reg_alpha : {Optional[float]}
         L1 regularization term on weights (xgb's alpha).
-    reg_lambda : Optional[float]
+
+    reg_lambda : {Optional[float]}
         L2 regularization term on weights (xgb's lambda).
-    scale_pos_weight : Optional[float]
+
+    scale_pos_weight : {Optional[float]}
         Balancing of positive and negative weights.
-    base_score : Optional[float]
+
+    base_score : {Optional[float]}
         The initial prediction score of all instances, global bias.
-    random_state : Optional[Union[numpy.random.RandomState, numpy.random.Generator, int]]
+
+    random_state : {Optional[Union[np.random.RandomState, np.random.Generator, int]]}
        Random number seed.
 
     .. note::
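
The `Optional[int]` → `{Optional[int]}` rewrites above work because `__estimator_doc` and `__model_doc` are f-strings, so the typing objects are interpolated when the module is imported. A quick standalone illustration:

    from typing import Optional

    doc = f"""
        n_estimators : {Optional[int]}
            Number of gradient boosted trees.
    """
    print(doc)  # the annotation renders as "typing.Optional[int]"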
@@ -248,34 +319,44 @@ __model_doc = f"""
         Using gblinear booster with shotgun updater is nondeterministic as
         it uses Hogwild algorithm.
 
-    missing : float, default np.nan
-        Value in the data which needs to be present as a missing value.
-    num_parallel_tree: Optional[int]
+    missing : float
+        Value in the data which needs to be present as a missing value. Default to
+        :py:data:`numpy.nan`.
+
+    num_parallel_tree: {Optional[int]}
         Used for boosting random forest.
-    monotone_constraints : Optional[Union[Dict[str, int], str]]
+
+    monotone_constraints : {Optional[Union[Dict[str, int], str]]}
         Constraint of variable monotonicity. See :doc:`tutorial </tutorials/monotonic>`
         for more information.
-    interaction_constraints : Optional[Union[str, List[Tuple[str]]]]
+
+    interaction_constraints : {Optional[Union[str, List[Tuple[str]]]]}
         Constraints for interaction representing permitted interactions. The
         constraints must be specified in the form of a nested list, e.g. ``[[0, 1], [2,
         3, 4]]``, where each inner list is a group of indices of features that are
         allowed to interact with each other. See :doc:`tutorial
         </tutorials/feature_interaction_constraint>` for more information
-    importance_type: Optional[str]
+
+    importance_type: {Optional[str]}
         The feature importance type for the feature_importances\\_ property:
 
         * For tree model, it's either "gain", "weight", "cover", "total_gain" or
           "total_cover".
-        * For linear model, only "weight" is defined and it's the normalized coefficients
-          without bias.
+        * For linear model, only "weight" is defined and it's the normalized
+          coefficients without bias.
 
-    device : Optional[str]
+    device : {Optional[str]}
 
         .. versionadded:: 2.0.0
 
        Device ordinal, available options are `cpu`, `cuda`, and `gpu`.
 
-    validate_parameters : Optional[bool]
+    validate_parameters : {Optional[bool]}
 
        Give warnings for unknown parameter.
@@ -283,14 +364,14 @@ __model_doc = f"""
         See the same parameter of :py:class:`DMatrix` for details.
 
-    feature_types : Optional[FeatureTypes]
+    feature_types : {Optional[FeatureTypes]}
 
         .. versionadded:: 1.7.0
 
         Used for specifying feature types without constructing a dataframe. See
         :py:class:`DMatrix` for details.
 
-    max_cat_to_onehot : Optional[int]
+    max_cat_to_onehot : {Optional[int]}
 
         .. versionadded:: 1.6.0
@@ -303,7 +384,7 @@ __model_doc = f"""
         categorical feature support. See :doc:`Categorical Data
         </tutorials/categorical>` and :ref:`cat-param` for details.
 
-    max_cat_threshold : Optional[int]
+    max_cat_threshold : {Optional[int]}
 
         .. versionadded:: 1.7.0
@@ -314,7 +395,7 @@ __model_doc = f"""
         needs to be set to have categorical feature support. See :doc:`Categorical Data
         </tutorials/categorical>` and :ref:`cat-param` for details.
 
-    multi_strategy : Optional[str]
+    multi_strategy : {Optional[str]}
 
         .. versionadded:: 2.0.0
@@ -327,7 +408,7 @@ __model_doc = f"""
         - ``one_output_per_tree``: One model for each target.
         - ``multi_output_tree``: Use multi-target trees.
 
-    eval_metric : Optional[Union[str, List[str], Callable]]
+    eval_metric : {Optional[Union[str, List[str], Callable]]}
 
         .. versionadded:: 1.6.0
@@ -360,7 +441,7 @@ __model_doc = f"""
             )
             reg.fit(X, y, eval_set=[(X, y)])
 
-    early_stopping_rounds : Optional[int]
+    early_stopping_rounds : {Optional[int]}
 
         .. versionadded:: 1.6.0
@@ -383,7 +464,8 @@ __model_doc = f"""
         early stopping. If there's more than one metric in **eval_metric**, the last
         metric will be used for early stopping.
 
-    callbacks : Optional[List[TrainingCallback]]
+    callbacks : {Optional[List[TrainingCallback]]}
+
         List of callback functions that are applied at end of each iteration.
         It is possible to use predefined callbacks by using
         :ref:`Callback API <callback_api>`.
@@ -402,7 +484,8 @@ __model_doc = f"""
             reg = xgboost.XGBRegressor(**params, callbacks=callbacks)
             reg.fit(X, y)
 
-    kwargs : dict, optional
+    kwargs : {Optional[Any]}
+
         Keyword arguments for XGBoost Booster object. Full documentation of parameters
         can be found :doc:`here </parameter>`.
         Attempting to set a parameter via the constructor args and \\*\\*kwargs
@@ -419,13 +502,16 @@ __custom_obj_note = """
     .. note::  Custom objective function
 
         A custom objective function can be provided for the ``objective``
-        parameter. In this case, it should have the signature
-        ``objective(y_true, y_pred) -> grad, hess``:
+        parameter. In this case, it should have the signature ``objective(y_true,
+        y_pred) -> [grad, hess]`` or ``objective(y_true, y_pred, *, sample_weight)
+        -> [grad, hess]``:
 
         y_true: array_like of shape [n_samples]
             The target values
         y_pred: array_like of shape [n_samples]
             The predicted values
+        sample_weight :
+            Optional sample weights.
+
         grad: array_like of shape [n_samples]
             The value of the gradient for each sample point.
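
A self-contained usage sketch of the weighted signature with the sklearn interface (synthetic data; not part of the commit):

    import numpy as np
    import xgboost as xgb

    def weighted_ls(y_true, y_pred, sample_weight=None):
        # Weighted least squares: grad = w * (pred - true), hess = w.
        grad = y_pred - y_true
        hess = np.ones_like(y_true)
        if sample_weight is not None:
            grad = grad * sample_weight
            hess = hess * sample_weight
        return grad, hess

    rng = np.random.default_rng(0)
    X = rng.normal(size=(256, 4))
    y = X @ np.array([1.0, -2.0, 0.5, 0.0]) + rng.normal(scale=0.1, size=256)
    w = rng.uniform(0.5, 1.5, size=256)

    reg = xgb.XGBRegressor(objective=weighted_ls, n_estimators=10)
    reg.fit(X, y, sample_weight=w)  # weights reach weighted_ls via the decorator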


@@ -815,10 +815,15 @@ def softprob_obj(
     return objective
 
 
-def ls_obj(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+def ls_obj(
+    y_true: np.ndarray, y_pred: np.ndarray, sample_weight: Optional[np.ndarray] = None
+) -> Tuple[np.ndarray, np.ndarray]:
     """Least squared error."""
     grad = y_pred - y_true
     hess = np.ones(len(y_true))
+    if sample_weight is not None:
+        grad *= sample_weight
+        hess *= sample_weight
     return grad, hess
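
Why `ls_obj` scales both outputs: for the weighted squared error L(p) = 1/2 * sum_i w_i * (p_i - y_i)^2, the per-sample gradient is w_i * (p_i - y_i) and the Hessian is w_i. A quick finite-difference check of the gradient (standalone sketch, not part of the commit):

    import numpy as np

    y_true = np.array([0.0, 1.0, 2.0])
    y_pred = np.array([0.5, 0.5, 2.5])
    w = np.array([1.0, 2.0, 0.5])

    def loss(p: np.ndarray) -> float:
        return 0.5 * float(np.sum(w * (p - y_true) ** 2))

    eps = 1e-6
    grad_fd = np.array(
        [(loss(y_pred + eps * e) - loss(y_pred - eps * e)) / (2 * eps)
         for e in np.eye(3)]
    )
    assert np.allclose(grad_fd, w * (y_pred - y_true), atol=1e-4)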


@@ -100,6 +100,7 @@ class LintersPaths:
     # demo
     "demo/json-model/json_parser.py",
     "demo/guide-python/external_memory.py",
+    "demo/guide-python/sklearn_examples.py",
     "demo/guide-python/continuation.py",
     "demo/guide-python/callbacks.py",
     "demo/guide-python/cat_in_the_dat.py",


@@ -517,6 +517,12 @@ def test_regression_with_custom_objective():
         labels = y[test_index]
         assert mean_squared_error(preds, labels) < 25
 
+    w = rng.uniform(low=0.0, high=1.0, size=X.shape[0])
+    reg = xgb.XGBRegressor(objective=tm.ls_obj, n_estimators=25)
+    reg.fit(X, y, sample_weight=w)
+    y_pred = reg.predict(X)
+    assert mean_squared_error(y_true=y, y_pred=y_pred, sample_weight=w) < 25
+
     # Test that the custom objective function is actually used
     class XGBCustomObjectiveException(Exception):
         pass


@@ -1750,9 +1750,20 @@ class TestWithDask:
         )
         tm.non_increasing(results_native["validation_0"]["rmse"])
 
+        reg = xgb.dask.DaskXGBRegressor(
+            n_estimators=rounds, objective=tm.ls_obj, tree_method="hist"
+        )
+        rng = da.random.RandomState(1994)
+        w = rng.uniform(low=0.0, high=1.0, size=y.shape[0])
+        reg.fit(
+            X, y, sample_weight=w, eval_set=[(X, y)], sample_weight_eval_set=[w]
+        )
+        results_custom = reg.evals_result()
+        tm.non_increasing(results_custom["validation_0"]["rmse"])
+
     def test_no_duplicated_partition(self) -> None:
-        """Assert each worker has the correct amount of data, and DMatrix initialization doesn't
-        generate unnecessary copies of data.
+        """Assert each worker has the correct amount of data, and DMatrix initialization
+        doesn't generate unnecessary copies of data.
         """
         with LocalCluster(n_workers=2, dashboard_address=":0") as cluster: