Support list and tuple inputs for QuantileDMatrix (QDM). (#8542)
This commit is contained in:
parent
8824b40961
commit
deb3edf562
@ -187,8 +187,7 @@ def _from_numpy_array(
|
||||
feature_types: Optional[FeatureTypes],
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
"""Initialize data from a 2-D numpy matrix."""
|
||||
if len(data.shape) != 2:
|
||||
raise ValueError("Expecting 2 dimensional numpy.ndarray, got: ", data.shape)
|
||||
_check_data_shape(data)
|
||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(
|
||||
@ -1199,6 +1198,8 @@ def _proxy_transform(
|
||||
return data, None, feature_names, feature_types
|
||||
if _is_dlpack(data):
|
||||
return _transform_dlpack(data), None, feature_names, feature_types
|
||||
if _is_list(data) or _is_tuple(data):
|
||||
data = np.array(data)
|
||||
if _is_numpy_array(data):
|
||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||
return data, None, feature_names, feature_types
|
||||
@ -1245,6 +1246,7 @@ def dispatch_proxy_set_data(
|
||||
raise err
|
||||
|
||||
if _is_numpy_array(data):
|
||||
_check_data_shape(data)
|
||||
proxy._set_data_from_array(data) # pylint: disable=W0212
|
||||
return
|
||||
if _is_scipy_csr(data):
|
||||
|
||||
@ -19,6 +19,7 @@ import xgboost as xgb
|
||||
|
||||
class TestQuantileDMatrix:
|
||||
def test_basic(self) -> None:
|
||||
"""Checks for np array, list, tuple."""
|
||||
n_samples = 234
|
||||
n_features = 8
|
||||
|
||||
@ -41,6 +42,18 @@ class TestQuantileDMatrix:
|
||||
assert Xy.num_row() == n_samples
|
||||
assert Xy.num_col() == n_features
|
||||
|
||||
n_samples = 64
|
||||
data = []
|
||||
for f in range(n_samples):
|
||||
row = [f] * n_features
|
||||
data.append(row)
|
||||
assert np.array(data).shape == (n_samples, n_features)
|
||||
Xy = xgb.QuantileDMatrix(data, max_bin=256)
|
||||
assert Xy.num_row() == n_samples
|
||||
assert Xy.num_col() == n_features
|
||||
r = np.arange(1.0, n_samples)
|
||||
np.testing.assert_allclose(Xy.get_data().toarray()[1:, 0], r)
|
||||
|
||||
@pytest.mark.parametrize("sparsity", [0.0, 0.1, 0.8, 0.9])
|
||||
def test_with_iterator(self, sparsity: float) -> None:
|
||||
n_samples_per_batch = 317
|
||||
@ -242,6 +255,7 @@ class TestQuantileDMatrix:
|
||||
)
|
||||
|
||||
def test_dtypes(self) -> None:
|
||||
"""Checks for both np array and pd DataFrame."""
|
||||
n_samples = 128
|
||||
n_features = 16
|
||||
for orig, x in np_dtypes(n_samples, n_features):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user