* [backport] Fix ranking with quantile dmatrix and group weight. (#8762)
* backport test utilities.
This commit is contained in:
parent
08a547f5c2
commit
f15a6d2b19
@ -58,6 +58,13 @@ void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, Bat
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
auto ellpack = [&]() {
|
auto ellpack = [&]() {
|
||||||
|
// workaround ellpack being initialized from CPU.
|
||||||
|
if (p.gpu_id == Context::kCpuId) {
|
||||||
|
p.gpu_id = ref_->Ctx()->gpu_id;
|
||||||
|
}
|
||||||
|
if (p.gpu_id == Context::kCpuId) {
|
||||||
|
p.gpu_id = 0;
|
||||||
|
}
|
||||||
for (auto const& page : ref_->GetBatches<EllpackPage>(p)) {
|
for (auto const& page : ref_->GetBatches<EllpackPage>(p)) {
|
||||||
GetCutsFromEllpack(page, p_cuts);
|
GetCutsFromEllpack(page, p_cuts);
|
||||||
break;
|
break;
|
||||||
@ -172,9 +179,9 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
|||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
while (iter.Next()) {
|
while (iter.Next()) {
|
||||||
if (!p_sketch) {
|
if (!p_sketch) {
|
||||||
p_sketch.reset(new common::HostSketchContainer{batch_param_.max_bin,
|
p_sketch.reset(new common::HostSketchContainer{
|
||||||
proxy->Info().feature_types.ConstHostSpan(),
|
batch_param_.max_bin, proxy->Info().feature_types.ConstHostSpan(), column_sizes,
|
||||||
column_sizes, false, ctx_.Threads()});
|
!proxy->Info().group_ptr_.empty(), ctx_.Threads()});
|
||||||
}
|
}
|
||||||
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
HostAdapterDispatch(proxy, [&](auto const& batch) {
|
||||||
proxy->Info().num_nonzero_ = batch_nnz[i];
|
proxy->Info().num_nonzero_ = batch_nnz[i];
|
||||||
|
|||||||
@ -139,3 +139,17 @@ class TestDeviceQuantileDMatrix:
|
|||||||
booster.predict(xgb.DMatrix(d_m.get_data())),
|
booster.predict(xgb.DMatrix(d_m.get_data())),
|
||||||
atol=1e-6,
|
atol=1e-6,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_ltr(self) -> None:
    """Compare QuantileDMatrix built with a DMatrix ref against a QuantileDMatrix ref.

    A learning-to-rank dataset with per-group weights is loaded twice: as a
    QuantileDMatrix constructed from a cupy array and as a plain DMatrix
    constructed from the numpy array.  After training on the plain DMatrix
    with ``gpu_hist``, each of the two is used as ``ref`` for a new
    QuantileDMatrix and the resulting predictors are required to be equal.
    """
    import cupy as cp

    features, labels, query_ids, group_weights = tm.make_ltr(100, 3, 3, 5)

    # make sure GPU is used to run sketching.
    device_features = cp.array(features)
    qdm_reference = xgb.QuantileDMatrix(
        device_features, labels, qid=query_ids, weight=group_weights
    )
    dm_reference = xgb.DMatrix(
        features, labels, qid=query_ids, weight=group_weights
    )

    train_params = {"tree_method": "gpu_hist", "objective": "rank:ndcg"}
    xgb.train(train_params, dm_reference)

    # Same construction order as before: DMatrix-referenced first, then
    # QuantileDMatrix-referenced.
    built_from_dm = xgb.QuantileDMatrix(
        features, weight=group_weights, ref=dm_reference
    )
    built_from_qdm = xgb.QuantileDMatrix(
        features, weight=group_weights, ref=qdm_reference
    )

    assert tm.predictor_equal(built_from_qdm, built_from_dm)
|
||||||
|
|||||||
@ -9,7 +9,9 @@ from testing import (
|
|||||||
make_batches,
|
make_batches,
|
||||||
make_batches_sparse,
|
make_batches_sparse,
|
||||||
make_categorical,
|
make_categorical,
|
||||||
|
make_ltr,
|
||||||
make_sparse_regression,
|
make_sparse_regression,
|
||||||
|
predictor_equal,
|
||||||
)
|
)
|
||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
@ -218,6 +220,16 @@ class TestQuantileDMatrix:
|
|||||||
b = booster.predict(qXy)
|
b = booster.predict(qXy)
|
||||||
np.testing.assert_allclose(a, b)
|
np.testing.assert_allclose(a, b)
|
||||||
|
|
||||||
|
def test_ltr(self) -> None:
    """Compare QuantileDMatrix built with a QuantileDMatrix ref against a DMatrix ref.

    A learning-to-rank dataset with per-group weights is loaded both as a
    QuantileDMatrix and as a plain DMatrix.  After training on the plain
    DMatrix with ``hist``, each of the two is used as ``ref`` for a new
    QuantileDMatrix and the resulting predictors are required to be equal.
    """
    features, labels, query_ids, group_weights = make_ltr(100, 3, 3, 5)

    qdm_reference = xgb.QuantileDMatrix(
        features, labels, qid=query_ids, weight=group_weights
    )
    dm_reference = xgb.DMatrix(
        features, labels, qid=query_ids, weight=group_weights
    )

    train_params = {"tree_method": "hist", "objective": "rank:ndcg"}
    xgb.train(train_params, dm_reference)

    # Same construction order as before: QuantileDMatrix-referenced first,
    # then DMatrix-referenced.
    built_from_qdm = xgb.QuantileDMatrix(
        features, weight=group_weights, ref=qdm_reference
    )
    built_from_dm = xgb.QuantileDMatrix(
        features, weight=group_weights, ref=dm_reference
    )

    assert predictor_equal(built_from_qdm, built_from_dm)
|
||||||
|
|
||||||
# we don't test empty Quantile DMatrix in single node construction.
|
# we don't test empty Quantile DMatrix in single node construction.
|
||||||
@given(
|
@given(
|
||||||
strategies.integers(1, 1000),
|
strategies.integers(1, 1000),
|
||||||
|
|||||||
@ -466,7 +466,22 @@ def make_categorical(
|
|||||||
return df, label
|
return df, label
|
||||||
|
|
||||||
|
|
||||||
def _cat_sampled_from():
|
def make_ltr(
    n_samples: int, n_features: int, n_query_groups: int, max_rel: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Make a dataset for testing LTR.

    The generator is seeded with a fixed value, so repeated calls with the
    same arguments return the same arrays.

    Parameters
    ----------
    n_samples :
        Number of rows in the feature matrix.
    n_features :
        Number of columns in the feature matrix.
    n_query_groups :
        Number of distinct query ids to draw from; also the length of the
        returned weight vector.
    max_rel :
        Exclusive upper bound of the integer relevance labels.

    Returns
    -------
    X :
        ``(n_samples, n_features)`` array of standard-normal features.
    y :
        ``n_samples`` integer labels in ``[0, max_rel)``.
    qid :
        ``n_samples`` query ids in ``[0, n_query_groups)``, sorted ascending.
    w :
        ``n_query_groups`` weights rescaled into ``[0, 1]``.
    """
    rng = np.random.default_rng(1994)
    # The order of the draws below is part of the contract: changing it would
    # change every array produced from the fixed seed.
    X = rng.normal(0, 1.0, size=n_samples * n_features).reshape(n_samples, n_features)
    y = rng.integers(0, max_rel, size=n_samples)
    qid = rng.integers(0, n_query_groups, size=n_samples)
    w = rng.normal(0, 1.0, size=n_query_groups)

    # Min-max scale the weights.  After the shift the smallest weight is 0.
    w -= np.min(w)
    w_max = np.max(w)
    if w_max > 0:
        w /= w_max
    else:
        # All weights are identical (always the case for a single query
        # group); dividing would give 0 / 0 = NaN, so use a uniform weight.
        w[:] = 1.0

    qid = np.sort(qid)
    return X, y, qid, w
|
||||||
|
|
||||||
|
|
||||||
|
def _cat_sampled_from() -> strategies.SearchStrategy:
|
||||||
@strategies.composite
|
@strategies.composite
|
||||||
def _make_cat(draw):
|
def _make_cat(draw):
|
||||||
n_samples = draw(strategies.integers(2, 512))
|
n_samples = draw(strategies.integers(2, 512))
|
||||||
@ -775,6 +790,19 @@ class DirectoryExcursion:
|
|||||||
os.remove(f)
|
os.remove(f)
|
||||||
|
|
||||||
|
|
||||||
|
def predictor_equal(lhs: xgb.DMatrix, rhs: xgb.DMatrix) -> bool:
    """Return whether two DMatrices contain the same predictors.

    The result of ``get_data()`` on each side is compared through its
    ``data``, ``indices`` and ``indptr`` arrays; all three must be equal.
    Nothing is raised on a mismatch, the caller is expected to assert on the
    returned value.
    """
    left = lhs.get_data()
    right = rhs.get_data()
    for field in ("data", "indices", "indptr"):
        if not np.array_equal(getattr(left, field), getattr(right, field)):
            return False
    return True
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def captured_output():
|
def captured_output():
|
||||||
"""Reassign stdout temporarily in order to test printed statements
|
"""Reassign stdout temporarily in order to test printed statements
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user