[EM] Get quantile cuts from the extmem qdm. (#10860)
This commit is contained in:
@@ -211,6 +211,7 @@ def check_extmem_qdm(
|
||||
cache="cache",
|
||||
on_host=on_host,
|
||||
)
|
||||
|
||||
Xy_it = xgb.ExtMemQuantileDMatrix(it)
|
||||
with pytest.raises(ValueError, match="Only the `hist`"):
|
||||
booster_it = xgb.train(
|
||||
@@ -227,12 +228,10 @@ def check_extmem_qdm(
|
||||
Xy = xgb.QuantileDMatrix(it)
|
||||
booster = xgb.train({"device": device}, Xy, num_boost_round=8)
|
||||
|
||||
if device == "cpu":
|
||||
# Get cuts from ellpack without CPU-GPU interpolation is not yet supported.
|
||||
cut_it = Xy_it.get_quantile_cut()
|
||||
cut = Xy.get_quantile_cut()
|
||||
np.testing.assert_allclose(cut_it[0], cut[0])
|
||||
np.testing.assert_allclose(cut_it[1], cut[1])
|
||||
cut_it = Xy_it.get_quantile_cut()
|
||||
cut = Xy.get_quantile_cut()
|
||||
np.testing.assert_allclose(cut_it[0], cut[0])
|
||||
np.testing.assert_allclose(cut_it[1], cut[1])
|
||||
|
||||
predt_it = booster_it.predict(Xy_it)
|
||||
predt = booster.predict(Xy)
|
||||
|
||||
Reference in New Issue
Block a user