Add more tests and doc for QDM. (#10692)

This commit is contained in:
Jiaming Yuan
2024-08-16 23:30:04 +08:00
committed by GitHub
parent 582ea104b5
commit 2258bc870d
7 changed files with 61 additions and 4 deletions

View File

@@ -1522,6 +1522,20 @@ class QuantileDMatrix(DMatrix):
.. versionadded:: 1.7.0
Examples
--------
.. code-block::
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
X, y = make_regression()
X_train, X_test, y_train, y_test = train_test_split(X, y)
Xy_train = xgb.QuantileDMatrix(X_train, y_train)
# It's necessary to have the training DMatrix as a reference for valid quantiles.
Xy_test = xgb.QuantileDMatrix(X_test, y_test, ref=Xy_train)
Parameters
----------
max_bin :

View File

@@ -0,0 +1,35 @@
"""QuantileDMatrix related tests."""
import numpy as np
from sklearn.model_selection import train_test_split
import xgboost as xgb
from .data import make_batches
def check_ref_quantile_cut(device: str) -> None:
"""Check obtaining the same cut values given a reference."""
X, y, _ = (
data[0]
for data in make_batches(
n_samples_per_batch=8192,
n_features=16,
n_batches=1,
use_cupy=device.startswith("cuda"),
)
)
X_train, X_valid, y_train, y_valid = train_test_split(X, y)
Xy_train = xgb.QuantileDMatrix(X_train, y_train)
Xy_valid = xgb.QuantileDMatrix(X_valid, y_valid, ref=Xy_train)
cut_train = Xy_train.get_quantile_cut()
cut_valid = Xy_valid.get_quantile_cut()
np.testing.assert_allclose(cut_train[0], cut_valid[0])
np.testing.assert_allclose(cut_train[1], cut_valid[1])
Xy_valid = xgb.QuantileDMatrix(X_valid, y_valid)
cut_valid = Xy_valid.get_quantile_cut()
assert not np.allclose(cut_train[1], cut_valid[1])

View File

@@ -250,10 +250,10 @@ def check_get_quantile_cut_device(tree_method: str, use_cupy: bool) -> None:
check_cut(n_entries, indptr, data, X.dtypes)
def check_get_quantile_cut(tree_method: str) -> None:
def check_get_quantile_cut(tree_method: str, device: str) -> None:
"""Check the quantile cut getter."""
use_cupy = tree_method == "gpu_hist"
use_cupy = device.startswith("cuda")
check_get_quantile_cut_device(tree_method, False)
if use_cupy:
check_get_quantile_cut_device(tree_method, True)