[EM] Get quantile cuts from the extmem qdm. (#10860)

This commit is contained in:
Jiaming Yuan
2024-10-01 00:59:28 +08:00
committed by GitHub
parent 8cf2f7aed8
commit 92f1c48a22
7 changed files with 35 additions and 14 deletions

View File

@@ -211,6 +211,7 @@ def check_extmem_qdm(
cache="cache",
on_host=on_host,
)
Xy_it = xgb.ExtMemQuantileDMatrix(it)
with pytest.raises(ValueError, match="Only the `hist`"):
booster_it = xgb.train(
@@ -227,12 +228,10 @@ def check_extmem_qdm(
Xy = xgb.QuantileDMatrix(it)
booster = xgb.train({"device": device}, Xy, num_boost_round=8)
if device == "cpu":
# Get cuts from ellpack without CPU-GPU interpolation is not yet supported.
cut_it = Xy_it.get_quantile_cut()
cut = Xy.get_quantile_cut()
np.testing.assert_allclose(cut_it[0], cut[0])
np.testing.assert_allclose(cut_it[1], cut[1])
cut_it = Xy_it.get_quantile_cut()
cut = Xy.get_quantile_cut()
np.testing.assert_allclose(cut_it[0], cut[0])
np.testing.assert_allclose(cut_it[1], cut[1])
predt_it = booster_it.predict(Xy_it)
predt = booster.predict(Xy)