Support custom metric in sklearn ranker. (#8786)
parent 17b709acb9
commit 225b3158f6
@@ -4,6 +4,7 @@ import copy
 import json
 import os
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from typing import (
     Any,
     Callable,
@@ -127,6 +128,49 @@ def _metric_decorator(func: Callable) -> Metric:
     return inner


+def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric:
+    """Decorate a learning to rank metric."""
+
+    def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
+        y_true = dmatrix.get_label()
+        group_ptr = dmatrix.get_uint_info("group_ptr")
+        if group_ptr.size < 2:
+            raise ValueError(
+                "Invalid `group_ptr`. Likely caused by invalid qid or group."
+            )
+        scores = np.empty(group_ptr.size - 1)
+        futures = []
+        weight = dmatrix.get_group()
+        no_weight = weight.size == 0
+
+        def task(i: int) -> float:
+            begin = group_ptr[i - 1]
+            end = group_ptr[i]
+            gy = y_true[begin:end]
+            gp = y_score[begin:end]
+            if gy.size == 1:
+                # Maybe there's a better default? 1.0 because many ranking score
+                # functions have output in range [0, 1].
+                return 1.0
+            return func(gy, gp)
+
+        workers = n_jobs if n_jobs is not None else os.cpu_count()
+        with ThreadPoolExecutor(max_workers=workers) as executor:
+            for i in range(1, group_ptr.size):
+                f = executor.submit(task, i)
+                futures.append(f)
+
+            for i, f in enumerate(futures):
+                scores[i] = f.result()
+
+        if no_weight:
+            return func.__name__, scores.mean()
+
+        return func.__name__, np.average(scores, weights=weight)
+
+    return inner
+
+
 __estimator_doc = """
 n_estimators : int
     Number of gradient boosted trees. Equivalent to number of boosting
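The decorator above splits the flat prediction vector into query groups using the `group_ptr` offsets and averages the per-group scores. Below is a minimal NumPy sketch of that per-group evaluation; the toy arrays and the choice of `roc_auc_score` as the group metric are illustrative and not part of the commit:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Toy data: 6 documents in 3 query groups. `group_ptr` is the CSR-style offset
# array exposed by dmatrix.get_uint_info("group_ptr"); group i spans
# [group_ptr[i - 1], group_ptr[i]).
y_true = np.array([1, 0, 0, 1, 1, 0])
y_score = np.array([0.9, 0.2, 0.4, 0.7, 0.6, 0.3])
group_ptr = np.array([0, 2, 4, 6])

scores = []
for i in range(1, group_ptr.size):
    begin, end = group_ptr[i - 1], group_ptr[i]
    gy, gp = y_true[begin:end], y_score[begin:end]
    # Single-document groups get a fixed score of 1.0, mirroring the decorator.
    scores.append(1.0 if gy.size == 1 else roc_auc_score(gy, gp))

print(np.mean(scores))  # unweighted mean over query groups
```

When per-group weights are present (`dmatrix.get_group()` is non-empty), the decorator uses `np.average(scores, weights=weight)` instead of the plain mean.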
@@ -868,7 +912,10 @@ class XGBModel(XGBModelBase):
             metric = eval_metric
         elif callable(eval_metric):
             # Parameter from constructor or set_params
-            metric = _metric_decorator(eval_metric)
+            if self._get_type() == "ranker":
+                metric = ltr_metric_decorator(eval_metric, self.n_jobs)
+            else:
+                metric = _metric_decorator(eval_metric)
         else:
             params.update({"eval_metric": eval_metric})

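With this dispatch, a callable `eval_metric` on a ranker is invoked once per query group, so it only needs the plain `(y_true, y_score) -> float` signature of an sklearn-style scorer. A sketch of such a metric follows; `precision_at_2` is a hypothetical example, not part of XGBoost:

```python
import numpy as np

def precision_at_2(y_true: np.ndarray, y_score: np.ndarray) -> float:
    """Fraction of relevant documents among the two highest-scored ones."""
    top = np.argsort(y_score)[::-1][:2]
    return float(np.mean(y_true[top] > 0))

# Called on a single query group, the way ltr_metric_decorator would call it.
print(precision_at_2(np.array([1, 0, 1, 0]), np.array([0.8, 0.7, 0.3, 0.1])))  # 0.5
```

Non-ranker estimators keep the existing `_metric_decorator` path, where the callable sees the whole evaluation set at once.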
@@ -1979,10 +2026,6 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
         ) = self._configure_fit(
             xgb_model, eval_metric, params, early_stopping_rounds, callbacks
         )
-        if callable(metric):
-            raise ValueError(
-                "Custom evaluation metric is not yet supported for XGBRanker."
-            )

         self._Booster = train(
             params,
@@ -154,6 +154,32 @@ def test_ranking():
     np.testing.assert_almost_equal(pred, pred_orig)


+def test_ranking_metric() -> None:
+    from sklearn.metrics import roc_auc_score
+
+    X, y, qid, w = tm.make_ltr(512, 4, 3, 2)
+    # use auc for test as ndcg_score in sklearn works only on label gain instead of exp
+    # gain.
+    # note that the auc in sklearn is different from the one in XGBoost. The one in
+    # sklearn compares the number of mis-classified docs, while the one in xgboost
+    # compares the number of mis-classified pairs.
+    ltr = xgb.XGBRanker(
+        eval_metric=roc_auc_score, n_estimators=10, tree_method="hist", max_depth=2
+    )
+    ltr.fit(
+        X,
+        y,
+        qid=qid,
+        sample_weight=w,
+        eval_set=[(X, y)],
+        eval_qid=[qid],
+        sample_weight_eval_set=[w],
+        verbose=True,
+    )
+    results = ltr.evals_result()
+    assert results["validation_0"]["roc_auc_score"][-1] > 0.6
+
+
 def test_stacking_regression():
     from sklearn.datasets import load_diabetes
     from sklearn.ensemble import RandomForestRegressor, StackingRegressor
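For reference, the same pattern outside the test suite: `tm.make_ltr` is an internal test helper, so this standalone sketch generates its own synthetic data. It assumes an XGBoost build that includes this change; the data shape, group count, and hyperparameters are illustrative:

```python
import numpy as np
import xgboost as xgb
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
X = rng.normal(size=(512, 4))
y = rng.integers(0, 2, size=512)               # binary relevance labels
qid = np.sort(rng.integers(0, 16, size=512))   # query ids must be sorted

ranker = xgb.XGBRanker(n_estimators=10, tree_method="hist", eval_metric=roc_auc_score)
ranker.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid], verbose=False)

# The metric is recorded under the callable's __name__.
print(ranker.evals_result()["validation_0"]["roc_auc_score"][-1])
```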
@@ -1426,10 +1452,10 @@ def test_weighted_evaluation_metric():
     X_train, X_test = X[:1600], X[1600:]
     y_train, y_test = y[:1600], y[1600:]
     weights_eval_set = np.random.choice([1, 2], len(X_test))

     np.random.seed(0)
     weights_train = np.random.choice([1, 2], len(X_train))

     clf = xgb.XGBClassifier(
         tree_method="hist",
         eval_metric=log_loss,