Add more tests and doc for QDM. (#10692)

This commit is contained in:
Jiaming Yuan 2024-08-16 23:30:04 +08:00 committed by GitHub
parent 582ea104b5
commit 2258bc870d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 61 additions and 4 deletions

View File

@ -1522,6 +1522,20 @@ class QuantileDMatrix(DMatrix):
.. versionadded:: 1.7.0 .. versionadded:: 1.7.0
Examples
--------
.. code-block::
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
X, y = make_regression()
X_train, X_test, y_train, y_test = train_test_split(X, y)
Xy_train = xgb.QuantileDMatrix(X_train, y_train)
# The training DMatrix must be passed as a reference so that the test data reuses its quantile cuts.
Xy_test = xgb.QuantileDMatrix(X_test, y_test, ref=Xy_train)
Parameters Parameters
---------- ----------
max_bin : max_bin :

View File

@ -0,0 +1,35 @@
"""QuantileDMatrix related tests."""
import numpy as np
from sklearn.model_selection import train_test_split
import xgboost as xgb
from .data import make_batches
def check_ref_quantile_cut(device: str) -> None:
    """Verify that a reference ``QuantileDMatrix`` shares its quantile cuts.

    A validation matrix constructed with ``ref`` pointing at the training matrix
    must report the same cuts as the training matrix, while a validation matrix
    constructed without a reference is expected to report different cut values.

    Parameters
    ----------
    device :
        Device name; the generated inputs use cupy when it starts with ``cuda``.

    """
    on_gpu = device.startswith("cuda")
    batches = make_batches(
        n_samples_per_batch=8192,
        n_features=16,
        n_batches=1,
        use_cupy=on_gpu,
    )
    # Only one batch is generated; take the first entry of each returned sequence.
    X, y, _ = (batch[0] for batch in batches)

    X_train, X_valid, y_train, y_valid = train_test_split(X, y)

    Xy_train = xgb.QuantileDMatrix(X_train, y_train)
    Xy_valid_ref = xgb.QuantileDMatrix(X_valid, y_valid, ref=Xy_train)

    cut_train = Xy_train.get_quantile_cut()
    cut_valid_ref = Xy_valid_ref.get_quantile_cut()

    # With a reference, both components of the cuts must agree.
    for idx in (0, 1):
        np.testing.assert_allclose(cut_train[idx], cut_valid_ref[idx])

    # Without a reference the cuts are computed from the validation data itself.
    Xy_valid_own = xgb.QuantileDMatrix(X_valid, y_valid)
    cut_valid_own = Xy_valid_own.get_quantile_cut()
    assert not np.allclose(cut_train[1], cut_valid_own[1])

View File

@ -250,10 +250,10 @@ def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None:
check_cut(n_entries, indptr, data, X.dtypes) check_cut(n_entries, indptr, data, X.dtypes)
def check_get_quantile_cut(tree_method: str) -> None: def check_get_quantile_cut(tree_method: str, device: str) -> None:
"""Check the quantile cut getter.""" """Check the quantile cut getter."""
use_cupy = tree_method == "gpu_hist" use_cupy = device.startswith("cuda")
check_get_quantile_cut_device(tree_method, False) check_get_quantile_cut_device(tree_method, False)
if use_cupy: if use_cupy:
check_get_quantile_cut_device(tree_method, True) check_get_quantile_cut_device(tree_method, True)

View File

@ -8,6 +8,7 @@ import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
from xgboost.testing.data import check_inf from xgboost.testing.data import check_inf
from xgboost.testing.data_iter import run_mixed_sparsity from xgboost.testing.data_iter import run_mixed_sparsity
from xgboost.testing.quantile_dmatrix import check_ref_quantile_cut
sys.path.append("tests/python") sys.path.append("tests/python")
import test_quantile_dmatrix as tqd import test_quantile_dmatrix as tqd
@ -142,6 +143,9 @@ class TestQuantileDMatrix:
{"tree_method": "approx", "max_bin": max_bin}, Xy, num_boost_round=4 {"tree_method": "approx", "max_bin": max_bin}, Xy, num_boost_round=4
) )
def test_ref_quantile_cut(self) -> None:
check_ref_quantile_cut("cuda")
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())
def test_metainfo(self) -> None: def test_metainfo(self) -> None:
import cupy as cp import cupy as cp

View File

@ -321,4 +321,4 @@ class TestGPUUpdaters:
@pytest.mark.skipif(**tm.no_cudf()) @pytest.mark.skipif(**tm.no_cudf())
def test_get_quantile_cut(self) -> None: def test_get_quantile_cut(self) -> None:
check_get_quantile_cut("gpu_hist") check_get_quantile_cut("hist", "cuda")

View File

@ -17,6 +17,7 @@ from xgboost.testing import (
) )
from xgboost.testing.data import check_inf, np_dtypes from xgboost.testing.data import check_inf, np_dtypes
from xgboost.testing.data_iter import run_mixed_sparsity from xgboost.testing.data_iter import run_mixed_sparsity
from xgboost.testing.quantile_dmatrix import check_ref_quantile_cut
class TestQuantileDMatrix: class TestQuantileDMatrix:
@ -266,6 +267,9 @@ class TestQuantileDMatrix:
dm_results["dvalid"]["rmse"], qdm_results["valid"]["rmse"] dm_results["dvalid"]["rmse"], qdm_results["valid"]["rmse"]
) )
def test_ref_quantile_cut(self) -> None:
check_ref_quantile_cut("cpu")
def test_ref_dmatrix(self) -> None: def test_ref_dmatrix(self) -> None:
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
self.run_ref_dmatrix(rng, "hist", True) self.run_ref_dmatrix(rng, "hist", True)

View File

@ -412,4 +412,4 @@ class TestTreeMethod:
@pytest.mark.skipif(**tm.no_pandas()) @pytest.mark.skipif(**tm.no_pandas())
@pytest.mark.parametrize("tree_method", ["hist"]) @pytest.mark.parametrize("tree_method", ["hist"])
def test_get_quantile_cut(self, tree_method: str) -> None: def test_get_quantile_cut(self, tree_method: str) -> None:
check_get_quantile_cut(tree_method) check_get_quantile_cut(tree_method, "cpu")