[backport] Fix ranking with quantile dmatrix and group weight. (#8762) (#8800)

* [backport] Fix ranking with quantile dmatrix and group weight. (#8762)

* backport test utilities.
This commit is contained in:
Jiaming Yuan
2023-02-15 02:45:09 +08:00
committed by GitHub
parent 08a547f5c2
commit f15a6d2b19
4 changed files with 65 additions and 4 deletions

View File

@@ -9,7 +9,9 @@ from testing import (
make_batches,
make_batches_sparse,
make_categorical,
make_ltr,
make_sparse_regression,
predictor_equal,
)
import xgboost as xgb
@@ -218,6 +220,16 @@ class TestQuantileDMatrix:
b = booster.predict(qXy)
np.testing.assert_allclose(a, b)
def test_ltr(self) -> None:
X, y, qid, w = make_ltr(100, 3, 3, 5)
Xy_qdm = xgb.QuantileDMatrix(X, y, qid=qid, weight=w)
Xy = xgb.DMatrix(X, y, qid=qid, weight=w)
xgb.train({"tree_method": "hist", "objective": "rank:ndcg"}, Xy)
from_qdm = xgb.QuantileDMatrix(X, weight=w, ref=Xy_qdm)
from_dm = xgb.QuantileDMatrix(X, weight=w, ref=Xy)
assert predictor_equal(from_qdm, from_dm)
# we don't test empty Quantile DMatrix in single node construction.
@given(
strategies.integers(1, 1000),

View File

@@ -466,7 +466,22 @@ def make_categorical(
return df, label
def _cat_sampled_from():
def make_ltr(
n_samples: int, n_features: int, n_query_groups: int, max_rel: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Make a dataset for testing LTR."""
rng = np.random.default_rng(1994)
X = rng.normal(0, 1.0, size=n_samples * n_features).reshape(n_samples, n_features)
y = rng.integers(0, max_rel, size=n_samples)
qid = rng.integers(0, n_query_groups, size=n_samples)
w = rng.normal(0, 1.0, size=n_query_groups)
w -= np.min(w)
w /= np.max(w)
qid = np.sort(qid)
return X, y, qid, w
def _cat_sampled_from() -> strategies.SearchStrategy:
@strategies.composite
def _make_cat(draw):
n_samples = draw(strategies.integers(2, 512))
@@ -775,6 +790,19 @@ class DirectoryExcursion:
os.remove(f)
def predictor_equal(lhs: xgb.DMatrix, rhs: xgb.DMatrix) -> bool:
"""Assert whether two DMatrices contain the same predictors."""
lcsr = lhs.get_data()
rcsr = rhs.get_data()
return all(
(
np.array_equal(lcsr.data, rcsr.data),
np.array_equal(lcsr.indices, rcsr.indices),
np.array_equal(lcsr.indptr, rcsr.indptr),
)
)
@contextmanager
def captured_output():
"""Reassign stdout temporarily in order to test printed statements