Support custom metric in sklearn ranker. (#8786)

This commit is contained in:
Jiaming Yuan
2023-02-12 13:14:07 +08:00
committed by GitHub
parent 17b709acb9
commit 225b3158f6
2 changed files with 76 additions and 7 deletions

View File

@@ -4,6 +4,7 @@ import copy
import json
import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Callable,
@@ -127,6 +128,49 @@ def _metric_decorator(func: Callable) -> Metric:
return inner
def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric:
"""Decorate a learning to rank metric."""
def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
y_true = dmatrix.get_label()
group_ptr = dmatrix.get_uint_info("group_ptr")
if group_ptr.size < 2:
raise ValueError(
"Invalid `group_ptr`. Likely caused by invalid qid or group."
)
scores = np.empty(group_ptr.size - 1)
futures = []
weight = dmatrix.get_group()
no_weight = weight.size == 0
def task(i: int) -> float:
begin = group_ptr[i - 1]
end = group_ptr[i]
gy = y_true[begin:end]
gp = y_score[begin:end]
if gy.size == 1:
# Maybe there's a better default? 1.0 because many ranking score
# functions have output in range [0, 1].
return 1.0
return func(gy, gp)
workers = n_jobs if n_jobs is not None else os.cpu_count()
with ThreadPoolExecutor(max_workers=workers) as executor:
for i in range(1, group_ptr.size):
f = executor.submit(task, i)
futures.append(f)
for i, f in enumerate(futures):
scores[i] = f.result()
if no_weight:
return func.__name__, scores.mean()
return func.__name__, np.average(scores, weights=weight)
return inner
__estimator_doc = """
n_estimators : int
Number of gradient boosted trees. Equivalent to number of boosting
@@ -868,7 +912,10 @@ class XGBModel(XGBModelBase):
metric = eval_metric
elif callable(eval_metric):
# Parameter from constructor or set_params
metric = _metric_decorator(eval_metric)
if self._get_type() == "ranker":
metric = ltr_metric_decorator(eval_metric, self.n_jobs)
else:
metric = _metric_decorator(eval_metric)
else:
params.update({"eval_metric": eval_metric})
@@ -1979,10 +2026,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
if callable(metric):
raise ValueError(
"Custom evaluation metric is not yet supported for XGBRanker."
)
self._Booster = train(
params,