Add more tests and doc for QDM. (#10692)
This commit is contained in:
parent
582ea104b5
commit
2258bc870d
@ -1522,6 +1522,20 @@ class QuantileDMatrix(DMatrix):
|
||||
|
||||
.. versionadded:: 1.7.0
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
.. code-block::
|
||||
|
||||
from sklearn.datasets import make_regression
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
X, y = make_regression()
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y)
|
||||
Xy_train = xgb.QuantileDMatrix(X_train, y_train)
|
||||
# The training DMatrix must be passed as a reference so that the test data is binned with the quantile cuts computed from the training data.
|
||||
Xy_test = xgb.QuantileDMatrix(X_test, y_test, ref=Xy_train)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_bin :
|
||||
|
||||
35
python-package/xgboost/testing/quantile_dmatrix.py
Normal file
35
python-package/xgboost/testing/quantile_dmatrix.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""QuantileDMatrix related tests."""
|
||||
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
from .data import make_batches
|
||||
|
||||
|
||||
def check_ref_quantile_cut(device: str) -> None:
    """Check obtaining the same cut values given a reference.

    A validation ``QuantileDMatrix`` constructed with ``ref`` set to the training
    matrix must report the same quantile cuts as the training matrix.  A validation
    matrix constructed without a reference is expected to report different cuts.

    Parameters
    ----------
    device :
        Device name.  When it starts with ``"cuda"``, ``use_cupy=True`` is passed to
        ``make_batches``.

    """
    # Only one batch is requested, so keep the first entry of each sequence
    # returned by ``make_batches``; the third one is not used here.
    X, y, _ = (
        data[0]
        for data in make_batches(
            n_samples_per_batch=8192,
            n_features=16,
            n_batches=1,
            use_cupy=device.startswith("cuda"),
        )
    )

    X_train, X_valid, y_train, y_valid = train_test_split(X, y)
    Xy_train = xgb.QuantileDMatrix(X_train, y_train)

    # With a reference, both arrays returned by ``get_quantile_cut`` must match the
    # ones obtained from the training matrix.
    Xy_valid = xgb.QuantileDMatrix(X_valid, y_valid, ref=Xy_train)
    cut_train = Xy_train.get_quantile_cut()
    cut_valid = Xy_valid.get_quantile_cut()
    np.testing.assert_allclose(cut_train[0], cut_valid[0])
    np.testing.assert_allclose(cut_train[1], cut_valid[1])

    # Without a reference, the validation matrix computes its own cuts.
    Xy_valid = xgb.QuantileDMatrix(X_valid, y_valid)
    cut_valid = Xy_valid.get_quantile_cut()
    train_cuts, valid_cuts = cut_train[1], cut_valid[1]
    # ``np.allclose`` raises ``ValueError`` for shapes that cannot be broadcast
    # together, so a differing shape is accepted as "different" before falling back
    # to the element-wise comparison.
    assert np.shape(train_cuts) != np.shape(valid_cuts) or not np.allclose(
        train_cuts, valid_cuts
    )
|
||||
@ -250,10 +250,10 @@ def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None:
|
||||
check_cut(n_entries, indptr, data, X.dtypes)
|
||||
|
||||
|
||||
def check_get_quantile_cut(tree_method: str) -> None:
|
||||
def check_get_quantile_cut(tree_method: str, device: str) -> None:
|
||||
"""Check the quantile cut getter."""
|
||||
|
||||
use_cupy = tree_method == "gpu_hist"
|
||||
use_cupy = device.startswith("cuda")
|
||||
check_get_quantile_cut_device(tree_method, False)
|
||||
if use_cupy:
|
||||
check_get_quantile_cut_device(tree_method, True)
|
||||
|
||||
@ -8,6 +8,7 @@ import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.testing.data import check_inf
|
||||
from xgboost.testing.data_iter import run_mixed_sparsity
|
||||
from xgboost.testing.quantile_dmatrix import check_ref_quantile_cut
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import test_quantile_dmatrix as tqd
|
||||
@ -142,6 +143,9 @@ class TestQuantileDMatrix:
|
||||
{"tree_method": "approx", "max_bin": max_bin}, Xy, num_boost_round=4
|
||||
)
|
||||
|
||||
# The shared check requests cupy inputs for a "cuda" device, so skip when cupy is
# not installed (same guard as the neighbouring cupy-based tests in this class).
@pytest.mark.skipif(**tm.no_cupy())
def test_ref_quantile_cut(self) -> None:
    """Run the shared reference quantile cut check with the ``"cuda"`` device."""
    check_ref_quantile_cut("cuda")
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_metainfo(self) -> None:
|
||||
import cupy as cp
|
||||
|
||||
@ -321,4 +321,4 @@ class TestGPUUpdaters:
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_get_quantile_cut(self) -> None:
|
||||
check_get_quantile_cut("gpu_hist")
|
||||
check_get_quantile_cut("hist", "cuda")
|
||||
|
||||
@ -17,6 +17,7 @@ from xgboost.testing import (
|
||||
)
|
||||
from xgboost.testing.data import check_inf, np_dtypes
|
||||
from xgboost.testing.data_iter import run_mixed_sparsity
|
||||
from xgboost.testing.quantile_dmatrix import check_ref_quantile_cut
|
||||
|
||||
|
||||
class TestQuantileDMatrix:
|
||||
@ -266,6 +267,9 @@ class TestQuantileDMatrix:
|
||||
dm_results["dvalid"]["rmse"], qdm_results["valid"]["rmse"]
|
||||
)
|
||||
|
||||
def test_ref_quantile_cut(self) -> None:
    """Run the shared reference quantile cut check with the ``"cpu"`` device."""
    check_ref_quantile_cut(device="cpu")
|
||||
|
||||
def test_ref_dmatrix(self) -> None:
|
||||
rng = np.random.RandomState(1994)
|
||||
self.run_ref_dmatrix(rng, "hist", True)
|
||||
|
||||
@ -412,4 +412,4 @@ class TestTreeMethod:
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
@pytest.mark.parametrize("tree_method", ["hist"])
|
||||
def test_get_quantile_cut(self, tree_method: str) -> None:
|
||||
check_get_quantile_cut(tree_method)
|
||||
check_get_quantile_cut(tree_method, "cpu")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user