Support custom metric in sklearn ranker. (#8786)

This commit is contained in:
Jiaming Yuan 2023-02-12 13:14:07 +08:00 committed by GitHub
parent 17b709acb9
commit 225b3158f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 76 additions and 7 deletions

View File

@ -4,6 +4,7 @@ import copy
import json import json
import os import os
import warnings import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import ( from typing import (
Any, Any,
Callable, Callable,
@ -127,6 +128,49 @@ def _metric_decorator(func: Callable) -> Metric:
return inner return inner
def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric:
"""Decorate a learning to rank metric."""
def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
y_true = dmatrix.get_label()
group_ptr = dmatrix.get_uint_info("group_ptr")
if group_ptr.size < 2:
raise ValueError(
"Invalid `group_ptr`. Likely caused by invalid qid or group."
)
scores = np.empty(group_ptr.size - 1)
futures = []
weight = dmatrix.get_group()
no_weight = weight.size == 0
def task(i: int) -> float:
begin = group_ptr[i - 1]
end = group_ptr[i]
gy = y_true[begin:end]
gp = y_score[begin:end]
if gy.size == 1:
# Maybe there's a better default? 1.0 because many ranking score
# functions have output in range [0, 1].
return 1.0
return func(gy, gp)
workers = n_jobs if n_jobs is not None else os.cpu_count()
with ThreadPoolExecutor(max_workers=workers) as executor:
for i in range(1, group_ptr.size):
f = executor.submit(task, i)
futures.append(f)
for i, f in enumerate(futures):
scores[i] = f.result()
if no_weight:
return func.__name__, scores.mean()
return func.__name__, np.average(scores, weights=weight)
return inner
__estimator_doc = """ __estimator_doc = """
n_estimators : int n_estimators : int
Number of gradient boosted trees. Equivalent to number of boosting Number of gradient boosted trees. Equivalent to number of boosting
@ -868,7 +912,10 @@ class XGBModel(XGBModelBase):
metric = eval_metric metric = eval_metric
elif callable(eval_metric): elif callable(eval_metric):
# Parameter from constructor or set_params # Parameter from constructor or set_params
metric = _metric_decorator(eval_metric) if self._get_type() == "ranker":
metric = ltr_metric_decorator(eval_metric, self.n_jobs)
else:
metric = _metric_decorator(eval_metric)
else: else:
params.update({"eval_metric": eval_metric}) params.update({"eval_metric": eval_metric})
@ -1979,10 +2026,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
) = self._configure_fit( ) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks xgb_model, eval_metric, params, early_stopping_rounds, callbacks
) )
if callable(metric):
raise ValueError(
"Custom evaluation metric is not yet supported for XGBRanker."
)
self._Booster = train( self._Booster = train(
params, params,

View File

@ -154,6 +154,32 @@ def test_ranking():
np.testing.assert_almost_equal(pred, pred_orig) np.testing.assert_almost_equal(pred, pred_orig)
def test_ranking_metric() -> None:
from sklearn.metrics import roc_auc_score
X, y, qid, w = tm.make_ltr(512, 4, 3, 2)
# use auc for test as ndcg_score in sklearn works only on label gain instead of exp
# gain.
# note that the auc in sklearn is different from the one in XGBoost. The one in
# sklearn compares the number of mis-classified docs, while the one in xgboost
# compares the number of mis-classified pairs.
ltr = xgb.XGBRanker(
eval_metric=roc_auc_score, n_estimators=10, tree_method="hist", max_depth=2
)
ltr.fit(
X,
y,
qid=qid,
sample_weight=w,
eval_set=[(X, y)],
eval_qid=[qid],
sample_weight_eval_set=[w],
verbose=True,
)
results = ltr.evals_result()
assert results["validation_0"]["roc_auc_score"][-1] > 0.6
def test_stacking_regression(): def test_stacking_regression():
from sklearn.datasets import load_diabetes from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor, StackingRegressor from sklearn.ensemble import RandomForestRegressor, StackingRegressor
@ -1426,10 +1452,10 @@ def test_weighted_evaluation_metric():
X_train, X_test = X[:1600], X[1600:] X_train, X_test = X[:1600], X[1600:]
y_train, y_test = y[:1600], y[1600:] y_train, y_test = y[:1600], y[1600:]
weights_eval_set = np.random.choice([1, 2], len(X_test)) weights_eval_set = np.random.choice([1, 2], len(X_test))
np.random.seed(0) np.random.seed(0)
weights_train = np.random.choice([1, 2], len(X_train)) weights_train = np.random.choice([1, 2], len(X_train))
clf = xgb.XGBClassifier( clf = xgb.XGBClassifier(
tree_method="hist", tree_method="hist",
eval_metric=log_loss, eval_metric=log_loss,