Support list and tuple for QDM. (#8542)

This commit is contained in:
Jiaming Yuan
2022-12-10 01:14:44 +08:00
committed by GitHub
parent 8824b40961
commit deb3edf562
2 changed files with 18 additions and 2 deletions

View File

@@ -19,6 +19,7 @@ import xgboost as xgb
class TestQuantileDMatrix:
def test_basic(self) -> None:
"""Checks for np array, list, tuple."""
n_samples = 234
n_features = 8
@@ -41,6 +42,18 @@ class TestQuantileDMatrix:
assert Xy.num_row() == n_samples
assert Xy.num_col() == n_features
n_samples = 64
data = []
for f in range(n_samples):
row = [f] * n_features
data.append(row)
assert np.array(data).shape == (n_samples, n_features)
Xy = xgb.QuantileDMatrix(data, max_bin=256)
assert Xy.num_row() == n_samples
assert Xy.num_col() == n_features
r = np.arange(1.0, n_samples)
np.testing.assert_allclose(Xy.get_data().toarray()[1:, 0], r)
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.8, 0.9])
def test_with_iterator(self, sparsity: float) -> None:
n_samples_per_batch = 317
@@ -242,6 +255,7 @@ class TestQuantileDMatrix:
)
def test_dtypes(self) -> None:
"""Checks for both np array and pd DataFrame."""
n_samples = 128
n_features = 16
for orig, x in np_dtypes(n_samples, n_features):