Fix ranking with quantile dmatrix and group weight. (#8762)

This commit is contained in:
Jiaming Yuan 2023-02-10 20:32:35 +08:00 committed by GitHub
parent ad0ccc6e4f
commit 8a16944664
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 48 additions and 1 deletions

View File

@ -556,6 +556,21 @@ def make_categorical(
return df, label return df, label
def make_ltr(
    n_samples: int, n_features: int, n_query_groups: int, max_rel: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Make a dataset for testing LTR.

    The generator is seeded with a fixed value, so repeated calls with the same
    arguments return the same data.

    Parameters
    ----------
    n_samples :
        Number of rows in the generated feature matrix.
    n_features :
        Number of columns in the generated feature matrix.
    n_query_groups :
        Number of query groups; query IDs are drawn from ``[0, n_query_groups)``
        and one weight is generated per group.
    max_rel :
        Exclusive upper bound of the integer relevance labels.

    Returns
    -------
    X, y, qid, w :
        Feature matrix of shape ``(n_samples, n_features)``, integer labels in
        ``[0, max_rel)``, query IDs sorted in ascending order, and per-group
        weights scaled into ``[0, 1]``.

    """
    rng = np.random.default_rng(1994)
    X = rng.normal(0, 1.0, size=n_samples * n_features).reshape(n_samples, n_features)
    y = rng.integers(0, max_rel, size=n_samples)
    qid = rng.integers(0, n_query_groups, size=n_samples)

    # Min-max scale the group weights into [0, 1].
    w = rng.normal(0, 1.0, size=n_query_groups)
    w -= np.min(w)
    w_max = np.max(w)
    if w_max > 0:
        w /= w_max
    else:
        # All weights are identical (always the case for a single query group).
        # Dividing would produce 0 / 0 = NaN, so fall back to uniform weights.
        w[:] = 1.0

    # Sort the IDs so rows belonging to the same query are contiguous.
    qid = np.sort(qid)
    return X, y, qid, w
def _cat_sampled_from() -> strategies.SearchStrategy: def _cat_sampled_from() -> strategies.SearchStrategy:
@strategies.composite @strategies.composite
def _make_cat(draw: Callable) -> Tuple[int, int, int, float]: def _make_cat(draw: Callable) -> Tuple[int, int, int, float]:

View File

@ -63,6 +63,13 @@ void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, Bat
} }
}; };
auto ellpack = [&]() { auto ellpack = [&]() {
// workaround ellpack being initialized from CPU.
if (p.gpu_id == Context::kCpuId) {
p.gpu_id = ref_->Ctx()->gpu_id;
}
if (p.gpu_id == Context::kCpuId) {
p.gpu_id = 0;
}
for (auto const& page : ref_->GetBatches<EllpackPage>(p)) { for (auto const& page : ref_->GetBatches<EllpackPage>(p)) {
GetCutsFromEllpack(page, p_cuts); GetCutsFromEllpack(page, p_cuts);
break; break;
@ -205,7 +212,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
h_ft = proxy->Info().feature_types.ConstHostVector(); h_ft = proxy->Info().feature_types.ConstHostVector();
SyncFeatureType(&h_ft); SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{ p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, h_ft, column_sizes, false, batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()}); proxy->Info().data_split_mode == DataSplitMode::kCol, ctx_.Threads()});
} }
HostAdapterDispatch(proxy, [&](auto const& batch) { HostAdapterDispatch(proxy, [&](auto const& batch) {

View File

@ -139,3 +139,17 @@ class TestQuantileDMatrix:
booster.predict(xgb.DMatrix(d_m.get_data())), booster.predict(xgb.DMatrix(d_m.get_data())),
atol=1e-6, atol=1e-6,
) )
def test_ltr(self) -> None:
    """Check LTR data with group weights when the reference is built on GPU.

    A ``QuantileDMatrix`` derived from a ``DMatrix`` reference and one derived
    from a ``QuantileDMatrix`` reference must compare equal under
    ``tm.predictor_equal``.
    """
    import cupy as cp

    features, labels, query_ids, group_weights = tm.make_ltr(100, 3, 3, 5)

    # make sure GPU is used to run sketching.
    device_features = cp.array(features)
    ref_qdm = xgb.QuantileDMatrix(
        device_features, labels, qid=query_ids, weight=group_weights
    )
    ref_dm = xgb.DMatrix(features, labels, qid=query_ids, weight=group_weights)

    # Train on the plain DMatrix before it is used as a reference below.
    train_params = {"tree_method": "gpu_hist", "objective": "rank:ndcg"}
    xgb.train(train_params, ref_dm)

    from_dm = xgb.QuantileDMatrix(features, weight=group_weights, ref=ref_dm)
    from_qdm = xgb.QuantileDMatrix(features, weight=group_weights, ref=ref_qdm)
    assert tm.predictor_equal(from_qdm, from_dm)

View File

@ -9,6 +9,7 @@ from xgboost.testing import (
make_batches, make_batches,
make_batches_sparse, make_batches_sparse,
make_categorical, make_categorical,
make_ltr,
make_sparse_regression, make_sparse_regression,
predictor_equal, predictor_equal,
) )
@ -233,6 +234,16 @@ class TestQuantileDMatrix:
b = booster.predict(qXy) b = booster.predict(qXy)
np.testing.assert_allclose(a, b) np.testing.assert_allclose(a, b)
def test_ltr(self) -> None:
    """Check LTR data with group weights on the CPU ``hist`` path.

    A ``QuantileDMatrix`` derived from a ``QuantileDMatrix`` reference and one
    derived from a ``DMatrix`` reference must compare equal under
    ``predictor_equal``.
    """
    features, labels, query_ids, group_weights = make_ltr(100, 3, 3, 5)

    ref_qdm = xgb.QuantileDMatrix(
        features, labels, qid=query_ids, weight=group_weights
    )
    ref_dm = xgb.DMatrix(features, labels, qid=query_ids, weight=group_weights)

    # Train on the plain DMatrix before it is used as a reference below.
    train_params = {"tree_method": "hist", "objective": "rank:ndcg"}
    xgb.train(train_params, ref_dm)

    from_qdm = xgb.QuantileDMatrix(features, weight=group_weights, ref=ref_qdm)
    from_dm = xgb.QuantileDMatrix(features, weight=group_weights, ref=ref_dm)
    assert predictor_equal(from_qdm, from_dm)
# we don't test empty Quantile DMatrix in single node construction. # we don't test empty Quantile DMatrix in single node construction.
@given( @given(
strategies.integers(1, 1000), strategies.integers(1, 1000),