Support sample weight in sklearn custom objective. (#10050)
commit 8ea705e4d5 (parent 69a17d5114)
@@ -804,10 +804,11 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
 
             Otherwise, one can pass a list-like input with the same length as number
             of columns in `data`, with the following possible values:
+
             - "c", which represents categorical columns.
             - "q", which represents numeric columns.
             - "int", which represents integer columns.
             - "i", which represents boolean columns.
 
             Note that, while categorical types are treated differently from
             the rest for model fitting purposes, the other types do not influence
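For context, the type markers documented in this hunk map one-to-one onto the columns of `data`. A minimal usage sketch (the dataset, shapes, and seed are illustrative, not part of the patch; `enable_categorical` is assumed to be required once a column is marked "c"):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = np.column_stack([
        rng.random(8),            # "q": numeric
        rng.integers(0, 3, 8),    # "c": categorical codes
        rng.integers(0, 10, 8),   # "int": integer
        rng.integers(0, 2, 8),    # "i": boolean
    ])
    y = rng.random(8)

    # One marker per column, in the same order as the list above.
    dtrain = xgb.DMatrix(
        X, label=y, feature_types=["q", "c", "int", "i"], enable_categorical=True
    )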
@@ -5,12 +5,14 @@ import json
 import os
 import warnings
 from concurrent.futures import ThreadPoolExecutor
+from inspect import signature
 from typing import (
     Any,
     Callable,
     Dict,
     List,
     Optional,
+    Protocol,
     Sequence,
     Tuple,
     Type,
@@ -67,14 +69,20 @@ def _can_use_qdm(tree_method: Optional[str]) -> bool:
     return tree_method in ("hist", "gpu_hist", None, "auto")
 
 
-SklObjective = Optional[
-    Union[str, Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]]
-]
+class _SklObjWProto(Protocol):  # pylint: disable=too-few-public-methods
+    def __call__(
+        self,
+        y_true: ArrayLike,
+        y_pred: ArrayLike,
+        sample_weight: Optional[ArrayLike],
+    ) -> Tuple[ArrayLike, ArrayLike]: ...
+
+
+_SklObjProto = Callable[[ArrayLike, ArrayLike], Tuple[np.ndarray, np.ndarray]]
+SklObjective = Optional[Union[str, _SklObjWProto, _SklObjProto]]
 
 
-def _objective_decorator(
-    func: Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]
-) -> Objective:
+def _objective_decorator(func: Union[_SklObjWProto, _SklObjProto]) -> Objective:
     """Decorate an objective function
 
     Converts an objective function using the typical sklearn metrics
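The `Protocol` introduced here uses structural typing: any callable that accepts `(y_true, y_pred, sample_weight)` satisfies `_SklObjWProto` without inheriting from it. A minimal sketch of a conforming objective (the function and its math are illustrative):

    from typing import Optional, Tuple

    import numpy as np


    def weighted_squared_error(
        y_true: np.ndarray, y_pred: np.ndarray, sample_weight: Optional[np.ndarray]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Gradient and Hessian of 0.5 * w * (y_pred - y_true)**2."""
        grad = y_pred - y_true
        hess = np.ones_like(y_true)
        if sample_weight is not None:
            grad = grad * sample_weight
            hess = hess * sample_weight
        return grad, hess

    # No subclassing needed: the signature alone makes this a valid
    # _SklObjWProto, so static checkers accept it for `objective=`.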
@@ -89,6 +97,8 @@ def _objective_decorator(
             The target values
         y_pred: array_like of shape [n_samples]
             The predicted values
+        sample_weight :
+            Optional sample weight, None or a ndarray.
 
     Returns
     -------
@@ -103,10 +113,25 @@ def _objective_decorator(
         ``dmatrix.get_label()``
     """
 
+    parameters = signature(func).parameters
+    supports_sw = "sample_weight" in parameters
+
     def inner(preds: np.ndarray, dmatrix: DMatrix) -> Tuple[np.ndarray, np.ndarray]:
-        """internal function"""
+        """Internal function."""
+        sample_weight = dmatrix.get_weight()
         labels = dmatrix.get_label()
-        return func(labels, preds)
+
+        if sample_weight.size > 0 and not supports_sw:
+            raise ValueError(
+                "Custom objective doesn't have the `sample_weight` parameter while"
+                " sample_weight is used."
+            )
+        if sample_weight.size > 0:
+            fnw = cast(_SklObjWProto, func)
+            return fnw(labels, preds, sample_weight=sample_weight)
+
+        fn = cast(_SklObjProto, func)
+        return fn(labels, preds)
 
     return inner
 
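The decorator inspects the user function once, when it is wrapped, so each boosting iteration pays only for the branch. The detection mechanism in isolation (a sketch; `supports_sample_weight` is a made-up helper name, not part of the patch):

    from inspect import signature


    def supports_sample_weight(func) -> bool:
        # Same check as in _objective_decorator above: a parameter literally
        # named `sample_weight` marks the objective as weight-aware.
        return "sample_weight" in signature(func).parameters


    def plain_obj(y_true, y_pred): ...
    def weighted_obj(y_true, y_pred, *, sample_weight): ...


    assert not supports_sample_weight(plain_obj)
    assert supports_sample_weight(weighted_obj)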
@@ -172,75 +197,121 @@ def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric:
     return inner
 
 
-__estimator_doc = """
-    n_estimators : Optional[int]
+__estimator_doc = f"""
+    n_estimators : {Optional[int]}
         Number of gradient boosted trees. Equivalent to number of boosting
         rounds.
 """
 
 __model_doc = f"""
-    max_depth : Optional[int]
+    max_depth : {Optional[int]}
+
         Maximum tree depth for base learners.
-    max_leaves :
+
+    max_leaves : {Optional[int]}
+
         Maximum number of leaves; 0 indicates no limit.
-    max_bin :
+
+    max_bin : {Optional[int]}
+
         If using histogram-based algorithm, maximum number of bins per feature
-    grow_policy :
-        Tree growing policy. 0: favor splitting at nodes closest to the node, i.e. grow
-        depth-wise. 1: favor splitting at nodes with highest loss change.
-    learning_rate : Optional[float]
+
+    grow_policy : {Optional[str]}
+
+        Tree growing policy.
+
+        - depthwise: Favors splitting at nodes closest to the node,
+        - lossguide: Favors splitting at nodes with highest loss change.
+
+    learning_rate : {Optional[float]}
+
         Boosting learning rate (xgb's "eta")
-    verbosity : Optional[int]
+
+    verbosity : {Optional[int]}
+
         The degree of verbosity. Valid values are 0 (silent) - 3 (debug).
+
     objective : {SklObjective}
+
         Specify the learning task and the corresponding learning objective or a custom
-        objective function to be used. For custom objective, see
-        :doc:`/tutorials/custom_metric_obj` and :ref:`custom-obj-metric` for more
-        information.
-    booster: Optional[str]
-        Specify which booster to use: `gbtree`, `gblinear` or `dart`.
-    tree_method: Optional[str]
+        objective function to be used.
+
+        For custom objective, see :doc:`/tutorials/custom_metric_obj` and
+        :ref:`custom-obj-metric` for more information, along with the end note for
+        function signatures.
+
+    booster: {Optional[str]}
+
+        Specify which booster to use: ``gbtree``, ``gblinear`` or ``dart``.
+
+    tree_method : {Optional[str]}
+
         Specify which tree method to use. Default to auto. If this parameter is set to
         default, XGBoost will choose the most conservative option available. It's
         recommended to study this option from the parameters document :doc:`tree method
         </treemethod>`
-    n_jobs : Optional[int]
+
+    n_jobs : {Optional[int]}
+
         Number of parallel threads used to run xgboost. When used with other
         Scikit-Learn algorithms like grid search, you may choose which algorithm to
         parallelize and balance the threads. Creating thread contention will
         significantly slow down both algorithms.
-    gamma : Optional[float]
-        (min_split_loss) Minimum loss reduction required to make a further partition on a
-        leaf node of the tree.
-    min_child_weight : Optional[float]
+
+    gamma : {Optional[float]}
+
+        (min_split_loss) Minimum loss reduction required to make a further partition on
+        a leaf node of the tree.
+
+    min_child_weight : {Optional[float]}
+
         Minimum sum of instance weight(hessian) needed in a child.
-    max_delta_step : Optional[float]
+
+    max_delta_step : {Optional[float]}
+
         Maximum delta step we allow each tree's weight estimation to be.
-    subsample : Optional[float]
+
+    subsample : {Optional[float]}
+
         Subsample ratio of the training instance.
-    sampling_method :
+
+    sampling_method : {Optional[str]}
+
         Sampling method. Used only by the GPU version of ``hist`` tree method.
-        - ``uniform``: select random training instances uniformly.
-        - ``gradient_based`` select random training instances with higher probability
+
+        - ``uniform``: Select random training instances uniformly.
+        - ``gradient_based``: Select random training instances with higher probability
           when the gradient and hessian are larger. (cf. CatBoost)
-    colsample_bytree : Optional[float]
+
+    colsample_bytree : {Optional[float]}
+
         Subsample ratio of columns when constructing each tree.
-    colsample_bylevel : Optional[float]
+
+    colsample_bylevel : {Optional[float]}
+
         Subsample ratio of columns for each level.
-    colsample_bynode : Optional[float]
+
+    colsample_bynode : {Optional[float]}
+
         Subsample ratio of columns for each split.
-    reg_alpha : Optional[float]
+
+    reg_alpha : {Optional[float]}
+
         L1 regularization term on weights (xgb's alpha).
-    reg_lambda : Optional[float]
+
+    reg_lambda : {Optional[float]}
+
         L2 regularization term on weights (xgb's lambda).
-    scale_pos_weight : Optional[float]
+
+    scale_pos_weight : {Optional[float]}
         Balancing of positive and negative weights.
-    base_score : Optional[float]
+
+    base_score : {Optional[float]}
+
         The initial prediction score of all instances, global bias.
-    random_state : Optional[Union[numpy.random.RandomState, numpy.random.Generator, int]]
+
+    random_state : {Optional[Union[np.random.RandomState, np.random.Generator, int]]}
+
         Random number seed.
 
         .. note::
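A side effect of converting these docstrings to f-strings is that the braced annotations render as the repr of the `typing` objects, keeping the documented types in sync with the annotations they mirror. A standalone illustration (not from the patch):

    from typing import Optional

    doc = f"""
        max_depth : {Optional[int]}
            Maximum tree depth for base learners.
    """
    print(doc)
    # ->
    #     max_depth : typing.Optional[int]
    #         Maximum tree depth for base learners.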
@@ -248,34 +319,44 @@ __model_doc = f"""
             Using gblinear booster with shotgun updater is nondeterministic as
             it uses Hogwild algorithm.
 
-    missing : float, default np.nan
-        Value in the data which needs to be present as a missing value.
-    num_parallel_tree: Optional[int]
+    missing : float
+
+        Value in the data which needs to be present as a missing value. Default to
+        :py:data:`numpy.nan`.
+
+    num_parallel_tree: {Optional[int]}
+
         Used for boosting random forest.
-    monotone_constraints : Optional[Union[Dict[str, int], str]]
+
+    monotone_constraints : {Optional[Union[Dict[str, int], str]]}
+
         Constraint of variable monotonicity. See :doc:`tutorial </tutorials/monotonic>`
         for more information.
-    interaction_constraints : Optional[Union[str, List[Tuple[str]]]]
+
+    interaction_constraints : {Optional[Union[str, List[Tuple[str]]]]}
+
         Constraints for interaction representing permitted interactions. The
         constraints must be specified in the form of a nested list, e.g. ``[[0, 1], [2,
         3, 4]]``, where each inner list is a group of indices of features that are
         allowed to interact with each other. See :doc:`tutorial
         </tutorials/feature_interaction_constraint>` for more information
-    importance_type: Optional[str]
+
+    importance_type: {Optional[str]}
+
         The feature importance type for the feature_importances\\_ property:
 
         * For tree model, it's either "gain", "weight", "cover", "total_gain" or
           "total_cover".
-        * For linear model, only "weight" is defined and it's the normalized coefficients
-          without bias.
+        * For linear model, only "weight" is defined and it's the normalized
+          coefficients without bias.
 
-    device : Optional[str]
+    device : {Optional[str]}
+
         .. versionadded:: 2.0.0
 
         Device ordinal, available options are `cpu`, `cuda`, and `gpu`.
 
-    validate_parameters : Optional[bool]
+    validate_parameters : {Optional[bool]}
+
         Give warnings for unknown parameter.
 
@@ -283,14 +364,14 @@ __model_doc = f"""
 
         See the same parameter of :py:class:`DMatrix` for details.
 
-    feature_types : Optional[FeatureTypes]
+    feature_types : {Optional[FeatureTypes]}
 
         .. versionadded:: 1.7.0
 
         Used for specifying feature types without constructing a dataframe. See
         :py:class:`DMatrix` for details.
 
-    max_cat_to_onehot : Optional[int]
+    max_cat_to_onehot : {Optional[int]}
 
         .. versionadded:: 1.6.0
 
@@ -303,7 +384,7 @@ __model_doc = f"""
         categorical feature support. See :doc:`Categorical Data
         </tutorials/categorical>` and :ref:`cat-param` for details.
 
-    max_cat_threshold : Optional[int]
+    max_cat_threshold : {Optional[int]}
 
         .. versionadded:: 1.7.0
 
@@ -314,7 +395,7 @@ __model_doc = f"""
         needs to be set to have categorical feature support. See :doc:`Categorical Data
         </tutorials/categorical>` and :ref:`cat-param` for details.
 
-    multi_strategy : Optional[str]
+    multi_strategy : {Optional[str]}
 
         .. versionadded:: 2.0.0
 
@@ -327,7 +408,7 @@ __model_doc = f"""
         - ``one_output_per_tree``: One model for each target.
         - ``multi_output_tree``: Use multi-target trees.
 
-    eval_metric : Optional[Union[str, List[str], Callable]]
+    eval_metric : {Optional[Union[str, List[str], Callable]]}
 
         .. versionadded:: 1.6.0
 
@@ -360,7 +441,7 @@ __model_doc = f"""
             )
             reg.fit(X, y, eval_set=[(X, y)])
 
-    early_stopping_rounds : Optional[int]
+    early_stopping_rounds : {Optional[int]}
 
         .. versionadded:: 1.6.0
 
@@ -383,7 +464,8 @@ __model_doc = f"""
         early stopping. If there's more than one metric in **eval_metric**, the last
         metric will be used for early stopping.
 
-    callbacks : Optional[List[TrainingCallback]]
+    callbacks : {Optional[List[TrainingCallback]]}
+
         List of callback functions that are applied at end of each iteration.
         It is possible to use predefined callbacks by using
         :ref:`Callback API <callback_api>`.
@@ -402,7 +484,8 @@ __model_doc = f"""
             reg = xgboost.XGBRegressor(**params, callbacks=callbacks)
             reg.fit(X, y)
 
-    kwargs : dict, optional
+    kwargs : {Optional[Any]}
+
         Keyword arguments for XGBoost Booster object. Full documentation of parameters
         can be found :doc:`here </parameter>`.
         Attempting to set a parameter via the constructor args and \\*\\*kwargs
@@ -419,13 +502,16 @@ __custom_obj_note = """
     .. note::  Custom objective function
 
         A custom objective function can be provided for the ``objective``
-        parameter. In this case, it should have the signature
-        ``objective(y_true, y_pred) -> grad, hess``:
+        parameter. In this case, it should have the signature ``objective(y_true,
+        y_pred) -> [grad, hess]`` or ``objective(y_true, y_pred, *, sample_weight)
+        -> [grad, hess]``:
 
         y_true: array_like of shape [n_samples]
             The target values
         y_pred: array_like of shape [n_samples]
             The predicted values
+        sample_weight :
+            Optional sample weights.
 
         grad: array_like of shape [n_samples]
             The value of the gradient for each sample point.
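Either documented signature can be passed straight to ``objective``. A minimal usage sketch against the scikit-learn wrapper with this patch applied (data, seed, and hyperparameters are arbitrary):

    import numpy as np
    import xgboost as xgb


    def sq_err(y_true, y_pred):
        # Unweighted form: objective(y_true, y_pred) -> [grad, hess]
        return y_pred - y_true, np.ones_like(y_true)


    def sq_err_w(y_true, y_pred, *, sample_weight):
        # Weighted form: objective(y_true, y_pred, *, sample_weight) -> [grad, hess]
        grad, hess = y_pred - y_true, np.ones_like(y_true)
        if sample_weight is not None:
            grad, hess = grad * sample_weight, hess * sample_weight
        return grad, hess


    rng = np.random.default_rng(0)
    X, y = rng.random((32, 4)), rng.random(32)
    w = rng.random(32)

    xgb.XGBRegressor(objective=sq_err, n_estimators=4).fit(X, y)
    xgb.XGBRegressor(objective=sq_err_w, n_estimators=4).fit(X, y, sample_weight=w)
    # Passing weights to the unweighted form raises the ValueError added above.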
@@ -815,10 +815,15 @@ def softprob_obj(
     return objective
 
 
-def ls_obj(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+def ls_obj(
+    y_true: np.ndarray, y_pred: np.ndarray, sample_weight: Optional[np.ndarray] = None
+) -> Tuple[np.ndarray, np.ndarray]:
     """Least squared error."""
     grad = y_pred - y_true
     hess = np.ones(len(y_true))
+    if sample_weight is not None:
+        grad *= sample_weight
+        hess *= sample_weight
     return grad, hess
 
 
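The scaling in ``ls_obj`` is the standard weighted least-squares identity: for the loss sum_i w_i * (y_pred_i - y_true_i)**2 / 2, the per-sample gradient is w_i * (y_pred_i - y_true_i) and the Hessian is w_i, so both unweighted quantities are simply multiplied by the weight. A standalone numeric check (reproducing the patched helper locally):

    import numpy as np


    def ls_obj(y_true, y_pred, sample_weight=None):
        # Same body as the patched testing helper above.
        grad = y_pred - y_true
        hess = np.ones(len(y_true))
        if sample_weight is not None:
            grad *= sample_weight
            hess *= sample_weight
        return grad, hess


    y_true = np.array([1.0, 2.0, 3.0])
    y_pred = np.array([1.5, 1.5, 3.5])
    w = np.array([1.0, 2.0, 0.5])

    grad, hess = ls_obj(y_true, y_pred, sample_weight=w)
    np.testing.assert_allclose(grad, w * (y_pred - y_true))  # w_i * residual_i
    np.testing.assert_allclose(hess, w)                      # second derivative is w_i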
@@ -100,6 +100,7 @@ class LintersPaths:
         # demo
         "demo/json-model/json_parser.py",
         "demo/guide-python/external_memory.py",
+        "demo/guide-python/sklearn_examples.py",
         "demo/guide-python/continuation.py",
         "demo/guide-python/callbacks.py",
         "demo/guide-python/cat_in_the_dat.py",
@@ -517,6 +517,12 @@ def test_regression_with_custom_objective():
         labels = y[test_index]
         assert mean_squared_error(preds, labels) < 25
 
+    w = rng.uniform(low=0.0, high=1.0, size=X.shape[0])
+    reg = xgb.XGBRegressor(objective=tm.ls_obj, n_estimators=25)
+    reg.fit(X, y, sample_weight=w)
+    y_pred = reg.predict(X)
+    assert mean_squared_error(y_true=y, y_pred=y_pred, sample_weight=w) < 25
+
     # Test that the custom objective function is actually used
     class XGBCustomObjectiveException(Exception):
         pass
@@ -1750,9 +1750,20 @@ class TestWithDask:
         )
         tm.non_increasing(results_native["validation_0"]["rmse"])
 
+        reg = xgb.dask.DaskXGBRegressor(
+            n_estimators=rounds, objective=tm.ls_obj, tree_method="hist"
+        )
+        rng = da.random.RandomState(1994)
+        w = rng.uniform(low=0.0, high=1.0, size=y.shape[0])
+        reg.fit(
+            X, y, sample_weight=w, eval_set=[(X, y)], sample_weight_eval_set=[w]
+        )
+        results_custom = reg.evals_result()
+        tm.non_increasing(results_custom["validation_0"]["rmse"])
+
     def test_no_duplicated_partition(self) -> None:
-        """Assert each worker has the correct amount of data, and DMatrix initialization doesn't
-        generate unnecessary copies of data.
+        """Assert each worker has the correct amount of data, and DMatrix initialization
+        doesn't generate unnecessary copies of data.
 
         """
         with LocalCluster(n_workers=2, dashboard_address=":0") as cluster: