Rework the MAP metric. (#8931)
- The new implementation is stricter: only binary labels are accepted. The previous implementation converted values greater than 1 to 1.
- Deterministic GPU implementation (no atomic add).
- Fix top-k handling.
- Precise definition of MAP (there are other variants of how to handle top-k).
- Refactor GPU ranking tests.
This commit is contained in:
@@ -14,6 +14,7 @@ import zipfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from platform import system
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -443,7 +444,7 @@ def get_mq2008(
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
|
||||
src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
|
||||
target = dpath + "/MQ2008.zip"
|
||||
target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip")
|
||||
if not os.path.exists(target):
|
||||
request.urlretrieve(url=src, filename=target)
|
||||
|
||||
@@ -462,9 +463,9 @@ def get_mq2008(
|
||||
qid_valid,
|
||||
) = load_svmlight_files(
|
||||
(
|
||||
dpath + "MQ2008/Fold1/train.txt",
|
||||
dpath + "MQ2008/Fold1/test.txt",
|
||||
dpath + "MQ2008/Fold1/vali.txt",
|
||||
Path(dpath) / "MQ2008" / "Fold1" / "train.txt",
|
||||
Path(dpath) / "MQ2008" / "Fold1" / "test.txt",
|
||||
Path(dpath) / "MQ2008" / "Fold1" / "vali.txt",
|
||||
),
|
||||
query_id=True,
|
||||
zero_based=False,
|
||||
|
||||
@@ -48,7 +48,12 @@ def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None:
|
||||
def neg_mse(*args: Any, **kwargs: Any) -> float:
|
||||
return -float(mean_squared_error(*args, **kwargs))
|
||||
|
||||
ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method)
|
||||
ranker = xgb.XGBRanker(
|
||||
n_estimators=3,
|
||||
eval_metric=neg_mse,
|
||||
tree_method=tree_method,
|
||||
disable_default_eval_metric=True,
|
||||
)
|
||||
ranker.fit(df, y, eval_set=[(valid_df, y)])
|
||||
score = ranker.score(valid_df, y)
|
||||
assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1])
|
||||
|
||||
Reference in New Issue
Block a user