Support list and tuple for QDM. (#8542)

This commit is contained in:
Jiaming Yuan
2022-12-10 01:14:44 +08:00
committed by GitHub
parent 8824b40961
commit deb3edf562
2 changed files with 18 additions and 2 deletions

View File

@@ -187,8 +187,7 @@ def _from_numpy_array(
feature_types: Optional[FeatureTypes],
) -> DispatchedDataBackendReturnType:
"""Initialize data from a 2-D numpy matrix."""
if len(data.shape) != 2:
raise ValueError("Expecting 2 dimensional numpy.ndarray, got: ", data.shape)
_check_data_shape(data)
data, _ = _ensure_np_dtype(data, data.dtype)
handle = ctypes.c_void_p()
_check_call(
@@ -1199,6 +1198,8 @@ def _proxy_transform(
return data, None, feature_names, feature_types
if _is_dlpack(data):
return _transform_dlpack(data), None, feature_names, feature_types
if _is_list(data) or _is_tuple(data):
data = np.array(data)
if _is_numpy_array(data):
data, _ = _ensure_np_dtype(data, data.dtype)
return data, None, feature_names, feature_types
@@ -1245,6 +1246,7 @@ def dispatch_proxy_set_data(
raise err
if _is_numpy_array(data):
_check_data_shape(data)
proxy._set_data_from_array(data) # pylint: disable=W0212
return
if _is_scipy_csr(data):