Support column-wise data split with in-memory inputs (#9628)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -303,14 +303,14 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
|
||||
|
||||
|
||||
def _validate_feature_info(
|
||||
feature_info: Sequence[str], n_features: int, name: str
|
||||
feature_info: Sequence[str], n_features: int, is_column_split: bool, name: str
|
||||
) -> List[str]:
|
||||
if isinstance(feature_info, str) or not isinstance(feature_info, Sequence):
|
||||
raise TypeError(
|
||||
f"Expecting a sequence of strings for {name}, got: {type(feature_info)}"
|
||||
)
|
||||
feature_info = list(feature_info)
|
||||
if len(feature_info) != n_features and n_features != 0:
|
||||
if len(feature_info) != n_features and n_features != 0 and not is_column_split:
|
||||
msg = (
|
||||
f"{name} must have the same length as the number of data columns, ",
|
||||
f"expected {n_features}, got {len(feature_info)}",
|
||||
@@ -1231,6 +1231,16 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
_check_call(_LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(ret)))
|
||||
return ret.value
|
||||
|
||||
def data_split_mode(self) -> DataSplitMode:
|
||||
"""Get the data split mode of the DMatrix.
|
||||
|
||||
.. versionadded:: 2.1.0
|
||||
|
||||
"""
|
||||
ret = c_bst_ulong()
|
||||
_check_call(_LIB.XGDMatrixDataSplitMode(self.handle, ctypes.byref(ret)))
|
||||
return DataSplitMode(ret.value)
|
||||
|
||||
def slice(
|
||||
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
|
||||
) -> "DMatrix":
|
||||
@@ -1298,7 +1308,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
|
||||
# validate feature name
|
||||
feature_names = _validate_feature_info(
|
||||
feature_names, self.num_col(), "feature names"
|
||||
feature_names,
|
||||
self.num_col(),
|
||||
self.data_split_mode() == DataSplitMode.COL,
|
||||
"feature names",
|
||||
)
|
||||
if len(feature_names) != len(set(feature_names)):
|
||||
values, counts = np.unique(
|
||||
@@ -1371,7 +1384,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
return
|
||||
|
||||
feature_types = _validate_feature_info(
|
||||
feature_types, self.num_col(), "feature types"
|
||||
feature_types,
|
||||
self.num_col(),
|
||||
self.data_split_mode() == DataSplitMode.COL,
|
||||
"feature types",
|
||||
)
|
||||
|
||||
feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types]
|
||||
|
||||
@@ -107,6 +107,7 @@ def _from_scipy_csr(
|
||||
nthread: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
"""Initialize data from a CSR matrix."""
|
||||
|
||||
@@ -118,7 +119,11 @@ def _from_scipy_csr(
|
||||
_array_interface(data.indices),
|
||||
_array_interface(data.data),
|
||||
c_bst_ulong(data.shape[1]),
|
||||
make_jcargs(missing=float(missing), nthread=int(nthread)),
|
||||
make_jcargs(
|
||||
missing=float(missing),
|
||||
nthread=int(nthread),
|
||||
data_split_mode=int(data_split_mode),
|
||||
),
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
@@ -139,6 +144,7 @@ def _from_scipy_csc(
|
||||
nthread: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
"""Initialize data from a CSC matrix."""
|
||||
handle = ctypes.c_void_p()
|
||||
@@ -149,7 +155,11 @@ def _from_scipy_csc(
|
||||
_array_interface(data.indices),
|
||||
_array_interface(data.data),
|
||||
c_bst_ulong(data.shape[0]),
|
||||
make_jcargs(missing=float(missing), nthread=int(nthread)),
|
||||
make_jcargs(
|
||||
missing=float(missing),
|
||||
nthread=int(nthread),
|
||||
data_split_mode=int(data_split_mode),
|
||||
),
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
@@ -518,11 +528,14 @@ def _from_pandas_df(
|
||||
nthread: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
data, feature_names, feature_types = _transform_pandas_df(
|
||||
data, enable_categorical, feature_names, feature_types
|
||||
)
|
||||
return _from_numpy_array(data, missing, nthread, feature_names, feature_types)
|
||||
return _from_numpy_array(
|
||||
data, missing, nthread, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
|
||||
|
||||
def _is_pandas_series(data: DataType) -> bool:
|
||||
@@ -970,10 +983,13 @@ def _from_list(
|
||||
n_threads: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
array = np.array(data)
|
||||
_check_data_shape(data)
|
||||
return _from_numpy_array(array, missing, n_threads, feature_names, feature_types)
|
||||
return _from_numpy_array(
|
||||
array, missing, n_threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
|
||||
|
||||
def _is_tuple(data: DataType) -> bool:
|
||||
@@ -986,8 +1002,11 @@ def _from_tuple(
|
||||
n_threads: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
return _from_list(data, missing, n_threads, feature_names, feature_types)
|
||||
return _from_list(
|
||||
data, missing, n_threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
|
||||
|
||||
def _is_iter(data: DataType) -> bool:
|
||||
@@ -1029,12 +1048,21 @@ def dispatch_data_backend(
|
||||
if not _is_cudf_ser(data) and not _is_pandas_series(data):
|
||||
_check_data_shape(data)
|
||||
if _is_scipy_csr(data):
|
||||
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
|
||||
return _from_scipy_csr(
|
||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
if _is_scipy_csc(data):
|
||||
return _from_scipy_csc(data, missing, threads, feature_names, feature_types)
|
||||
return _from_scipy_csc(
|
||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
if _is_scipy_coo(data):
|
||||
return _from_scipy_csr(
|
||||
data.tocsr(), missing, threads, feature_names, feature_types
|
||||
data.tocsr(),
|
||||
missing,
|
||||
threads,
|
||||
feature_names,
|
||||
feature_types,
|
||||
data_split_mode,
|
||||
)
|
||||
if _is_np_array_like(data):
|
||||
return _from_numpy_array(
|
||||
@@ -1043,9 +1071,13 @@ def dispatch_data_backend(
|
||||
if _is_uri(data):
|
||||
return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
|
||||
if _is_list(data):
|
||||
return _from_list(data, missing, threads, feature_names, feature_types)
|
||||
return _from_list(
|
||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
if _is_tuple(data):
|
||||
return _from_tuple(data, missing, threads, feature_names, feature_types)
|
||||
return _from_tuple(
|
||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_series(data):
|
||||
@@ -1054,7 +1086,13 @@ def dispatch_data_backend(
|
||||
data = pd.DataFrame(data)
|
||||
if _is_pandas_df(data):
|
||||
return _from_pandas_df(
|
||||
data, enable_categorical, missing, threads, feature_names, feature_types
|
||||
data,
|
||||
enable_categorical,
|
||||
missing,
|
||||
threads,
|
||||
feature_names,
|
||||
feature_types,
|
||||
data_split_mode,
|
||||
)
|
||||
if _is_cudf_df(data) or _is_cudf_ser(data):
|
||||
return _from_cudf_df(
|
||||
|
||||
@@ -10,6 +10,7 @@ import os
|
||||
import platform
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
@@ -34,6 +35,7 @@ import pytest
|
||||
from scipy import sparse
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import RabitTracker
|
||||
from xgboost.core import ArrayLike
|
||||
from xgboost.sklearn import SklObjective
|
||||
from xgboost.testing.data import (
|
||||
@@ -938,3 +940,22 @@ def load_agaricus(path: str) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
|
||||
|
||||
def project_root(path: str) -> str:
|
||||
return normpath(os.path.join(demo_dir(path), os.path.pardir))
|
||||
|
||||
|
||||
def run_with_rabit(world_size: int, test_fn: Callable) -> None:
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
|
||||
tracker.start(world_size)
|
||||
|
||||
def run_worker(rabit_env: Dict[str, Union[str, int]]) -> None:
|
||||
with xgb.collective.CommunicatorContext(**rabit_env):
|
||||
test_fn()
|
||||
|
||||
workers = []
|
||||
for _ in range(world_size):
|
||||
worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),))
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
|
||||
tracker.join()
|
||||
|
||||
Reference in New Issue
Block a user