Support list and tuple for QuantileDMatrix (QDM). (#8542)

This commit is contained in:
Jiaming Yuan 2022-12-10 01:14:44 +08:00 committed by GitHub
parent 8824b40961
commit deb3edf562
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 2 deletions

View File

@ -187,8 +187,7 @@ def _from_numpy_array(
feature_types: Optional[FeatureTypes], feature_types: Optional[FeatureTypes],
) -> DispatchedDataBackendReturnType: ) -> DispatchedDataBackendReturnType:
"""Initialize data from a 2-D numpy matrix.""" """Initialize data from a 2-D numpy matrix."""
if len(data.shape) != 2: _check_data_shape(data)
raise ValueError("Expecting 2 dimensional numpy.ndarray, got: ", data.shape)
data, _ = _ensure_np_dtype(data, data.dtype) data, _ = _ensure_np_dtype(data, data.dtype)
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
_check_call( _check_call(
@ -1199,6 +1198,8 @@ def _proxy_transform(
return data, None, feature_names, feature_types return data, None, feature_names, feature_types
if _is_dlpack(data): if _is_dlpack(data):
return _transform_dlpack(data), None, feature_names, feature_types return _transform_dlpack(data), None, feature_names, feature_types
if _is_list(data) or _is_tuple(data):
data = np.array(data)
if _is_numpy_array(data): if _is_numpy_array(data):
data, _ = _ensure_np_dtype(data, data.dtype) data, _ = _ensure_np_dtype(data, data.dtype)
return data, None, feature_names, feature_types return data, None, feature_names, feature_types
@ -1245,6 +1246,7 @@ def dispatch_proxy_set_data(
raise err raise err
if _is_numpy_array(data): if _is_numpy_array(data):
_check_data_shape(data)
proxy._set_data_from_array(data) # pylint: disable=W0212 proxy._set_data_from_array(data) # pylint: disable=W0212
return return
if _is_scipy_csr(data): if _is_scipy_csr(data):

View File

@ -19,6 +19,7 @@ import xgboost as xgb
class TestQuantileDMatrix: class TestQuantileDMatrix:
def test_basic(self) -> None: def test_basic(self) -> None:
"""Checks for np array, list, tuple."""
n_samples = 234 n_samples = 234
n_features = 8 n_features = 8
@ -41,6 +42,18 @@ class TestQuantileDMatrix:
assert Xy.num_row() == n_samples assert Xy.num_row() == n_samples
assert Xy.num_col() == n_features assert Xy.num_col() == n_features
n_samples = 64
data = []
for f in range(n_samples):
row = [f] * n_features
data.append(row)
assert np.array(data).shape == (n_samples, n_features)
Xy = xgb.QuantileDMatrix(data, max_bin=256)
assert Xy.num_row() == n_samples
assert Xy.num_col() == n_features
r = np.arange(1.0, n_samples)
np.testing.assert_allclose(Xy.get_data().toarray()[1:, 0], r)
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.8, 0.9]) @pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.8, 0.9])
def test_with_iterator(self, sparsity: float) -> None: def test_with_iterator(self, sparsity: float) -> None:
n_samples_per_batch = 317 n_samples_per_batch = 317
@ -242,6 +255,7 @@ class TestQuantileDMatrix:
) )
def test_dtypes(self) -> None: def test_dtypes(self) -> None:
"""Checks for both np array and pd DataFrame."""
n_samples = 128 n_samples = 128
n_features = 16 n_features = 16
for orig, x in np_dtypes(n_samples, n_features): for orig, x in np_dtypes(n_samples, n_features):