Support column-wise data split with in-memory inputs (#9628)

---------

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Rong Ou 2023-10-16 21:16:39 -07:00 committed by GitHub
parent 4d1607eefd
commit da6803b75b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 307 additions and 27 deletions

View File

@ -144,9 +144,7 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
* See :doc:`/tutorials/input_format` for more info. * See :doc:`/tutorials/input_format` for more info.
* \endverbatim * \endverbatim
* - silent (optional): Whether to print message during loading. Default to true. * - silent (optional): Whether to print message during loading. Default to true.
* - data_split_mode (optional): Whether to split by row or column. In distributed mode, the * - data_split_mode (optional): Whether the file was split by row or column beforehand for distributed computing. Default to row.
* file is split accordingly; otherwise this is only an indicator on how the file was split
* beforehand. Default to row.
* \param out a loaded data matrix * \param out a loaded data matrix
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
@ -174,6 +172,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t *indptr, const unsigned *indic
* \param config JSON encoded configuration. Required values are: * \param config JSON encoded configuration. Required values are:
* - missing: Which value to represent missing value. * - missing: Which value to represent missing value.
* - nthread (optional): Number of threads used for initializing DMatrix. * - nthread (optional): Number of threads used for initializing DMatrix.
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
* \param out created dmatrix * \param out created dmatrix
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
@ -186,6 +185,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
* \param config JSON encoded configuration. Required values are: * \param config JSON encoded configuration. Required values are:
* - missing: Which value to represent missing value. * - missing: Which value to represent missing value.
* - nthread (optional): Number of threads used for initializing DMatrix. * - nthread (optional): Number of threads used for initializing DMatrix.
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
* \param out created dmatrix * \param out created dmatrix
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
@ -200,6 +200,7 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data, char const *config, DMatr
* \param config JSON encoded configuration. Supported values are: * \param config JSON encoded configuration. Supported values are:
* - missing: Which value to represent missing value. * - missing: Which value to represent missing value.
* - nthread (optional): Number of threads used for initializing DMatrix. * - nthread (optional): Number of threads used for initializing DMatrix.
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
* \param out created dmatrix * \param out created dmatrix
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
@ -266,6 +267,7 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data,
* \param config JSON encoded configuration. Required values are: * \param config JSON encoded configuration. Required values are:
* - missing: Which value to represent missing value. * - missing: Which value to represent missing value.
* - nthread (optional): Number of threads used for initializing DMatrix. * - nthread (optional): Number of threads used for initializing DMatrix.
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
* \param out created dmatrix * \param out created dmatrix
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
@ -278,6 +280,7 @@ XGB_DLL int XGDMatrixCreateFromCudaColumnar(char const *data, char const *config
* \param config JSON encoded configuration. Required values are: * \param config JSON encoded configuration. Required values are:
* - missing: Which value to represent missing value. * - missing: Which value to represent missing value.
* - nthread (optional): Number of threads used for initializing DMatrix. * - nthread (optional): Number of threads used for initializing DMatrix.
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
* \param out created dmatrix * \param out created dmatrix
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
@ -790,6 +793,16 @@ XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle, bst_ulong *out);
*/ */
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out); XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
/*!
 * \brief Get the data split mode from DMatrix.
 *
 * The value reported is the `data_split_mode` recorded in the DMatrix's meta info, for
 * example the one supplied through the optional `data_split_mode` entry of the JSON
 * configuration accepted by the DMatrix creation functions.
 *
 * \param handle the handle to the DMatrix
 * \param out Receives the data split mode as an unsigned integer (the numeric value of
 *            the split mode enum). 0 is the default used by the creation functions,
 *            which they document as a row-wise split.
 *
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out);
/** /**
* \brief Get the predictors from DMatrix as CSR matrix for testing. If this is a * \brief Get the predictors from DMatrix as CSR matrix for testing. If this is a
* quantized DMatrix, quantized values are returned instead. * quantized DMatrix, quantized values are returned instead.

View File

@ -303,14 +303,14 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
def _validate_feature_info( def _validate_feature_info(
feature_info: Sequence[str], n_features: int, name: str feature_info: Sequence[str], n_features: int, is_column_split: bool, name: str
) -> List[str]: ) -> List[str]:
if isinstance(feature_info, str) or not isinstance(feature_info, Sequence): if isinstance(feature_info, str) or not isinstance(feature_info, Sequence):
raise TypeError( raise TypeError(
f"Expecting a sequence of strings for {name}, got: {type(feature_info)}" f"Expecting a sequence of strings for {name}, got: {type(feature_info)}"
) )
feature_info = list(feature_info) feature_info = list(feature_info)
if len(feature_info) != n_features and n_features != 0: if len(feature_info) != n_features and n_features != 0 and not is_column_split:
msg = ( msg = (
f"{name} must have the same length as the number of data columns, ", f"{name} must have the same length as the number of data columns, ",
f"expected {n_features}, got {len(feature_info)}", f"expected {n_features}, got {len(feature_info)}",
@ -1231,6 +1231,16 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
_check_call(_LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(ret))) _check_call(_LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(ret)))
return ret.value return ret.value
def data_split_mode(self) -> DataSplitMode:
    """Return the data split mode recorded in this DMatrix.

    .. versionadded:: 2.1.0

    """
    # The C API writes the mode into an unsigned integer; convert it to the enum.
    mode = c_bst_ulong()
    status = _LIB.XGDMatrixDataSplitMode(self.handle, ctypes.byref(mode))
    _check_call(status)
    return DataSplitMode(mode.value)
def slice( def slice(
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
) -> "DMatrix": ) -> "DMatrix":
@ -1298,7 +1308,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
# validate feature name # validate feature name
feature_names = _validate_feature_info( feature_names = _validate_feature_info(
feature_names, self.num_col(), "feature names" feature_names,
self.num_col(),
self.data_split_mode() == DataSplitMode.COL,
"feature names",
) )
if len(feature_names) != len(set(feature_names)): if len(feature_names) != len(set(feature_names)):
values, counts = np.unique( values, counts = np.unique(
@ -1371,7 +1384,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
return return
feature_types = _validate_feature_info( feature_types = _validate_feature_info(
feature_types, self.num_col(), "feature types" feature_types,
self.num_col(),
self.data_split_mode() == DataSplitMode.COL,
"feature types",
) )
feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types] feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types]

View File

@ -107,6 +107,7 @@ def _from_scipy_csr(
nthread: int, nthread: int,
feature_names: Optional[FeatureNames], feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes], feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType: ) -> DispatchedDataBackendReturnType:
"""Initialize data from a CSR matrix.""" """Initialize data from a CSR matrix."""
@ -118,7 +119,11 @@ def _from_scipy_csr(
_array_interface(data.indices), _array_interface(data.indices),
_array_interface(data.data), _array_interface(data.data),
c_bst_ulong(data.shape[1]), c_bst_ulong(data.shape[1]),
make_jcargs(missing=float(missing), nthread=int(nthread)), make_jcargs(
missing=float(missing),
nthread=int(nthread),
data_split_mode=int(data_split_mode),
),
ctypes.byref(handle), ctypes.byref(handle),
) )
) )
@ -139,6 +144,7 @@ def _from_scipy_csc(
nthread: int, nthread: int,
feature_names: Optional[FeatureNames], feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes], feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType: ) -> DispatchedDataBackendReturnType:
"""Initialize data from a CSC matrix.""" """Initialize data from a CSC matrix."""
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
@ -149,7 +155,11 @@ def _from_scipy_csc(
_array_interface(data.indices), _array_interface(data.indices),
_array_interface(data.data), _array_interface(data.data),
c_bst_ulong(data.shape[0]), c_bst_ulong(data.shape[0]),
make_jcargs(missing=float(missing), nthread=int(nthread)), make_jcargs(
missing=float(missing),
nthread=int(nthread),
data_split_mode=int(data_split_mode),
),
ctypes.byref(handle), ctypes.byref(handle),
) )
) )
@ -518,11 +528,14 @@ def _from_pandas_df(
nthread: int, nthread: int,
feature_names: Optional[FeatureNames], feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes], feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType: ) -> DispatchedDataBackendReturnType:
data, feature_names, feature_types = _transform_pandas_df( data, feature_names, feature_types = _transform_pandas_df(
data, enable_categorical, feature_names, feature_types data, enable_categorical, feature_names, feature_types
) )
return _from_numpy_array(data, missing, nthread, feature_names, feature_types) return _from_numpy_array(
data, missing, nthread, feature_names, feature_types, data_split_mode
)
def _is_pandas_series(data: DataType) -> bool: def _is_pandas_series(data: DataType) -> bool:
@ -970,10 +983,13 @@ def _from_list(
n_threads: int, n_threads: int,
feature_names: Optional[FeatureNames], feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes], feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType: ) -> DispatchedDataBackendReturnType:
array = np.array(data) array = np.array(data)
_check_data_shape(data) _check_data_shape(data)
return _from_numpy_array(array, missing, n_threads, feature_names, feature_types) return _from_numpy_array(
array, missing, n_threads, feature_names, feature_types, data_split_mode
)
def _is_tuple(data: DataType) -> bool: def _is_tuple(data: DataType) -> bool:
@ -986,8 +1002,11 @@ def _from_tuple(
n_threads: int, n_threads: int,
feature_names: Optional[FeatureNames], feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes], feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType: ) -> DispatchedDataBackendReturnType:
return _from_list(data, missing, n_threads, feature_names, feature_types) return _from_list(
data, missing, n_threads, feature_names, feature_types, data_split_mode
)
def _is_iter(data: DataType) -> bool: def _is_iter(data: DataType) -> bool:
@ -1029,12 +1048,21 @@ def dispatch_data_backend(
if not _is_cudf_ser(data) and not _is_pandas_series(data): if not _is_cudf_ser(data) and not _is_pandas_series(data):
_check_data_shape(data) _check_data_shape(data)
if _is_scipy_csr(data): if _is_scipy_csr(data):
return _from_scipy_csr(data, missing, threads, feature_names, feature_types) return _from_scipy_csr(
data, missing, threads, feature_names, feature_types, data_split_mode
)
if _is_scipy_csc(data): if _is_scipy_csc(data):
return _from_scipy_csc(data, missing, threads, feature_names, feature_types) return _from_scipy_csc(
data, missing, threads, feature_names, feature_types, data_split_mode
)
if _is_scipy_coo(data): if _is_scipy_coo(data):
return _from_scipy_csr( return _from_scipy_csr(
data.tocsr(), missing, threads, feature_names, feature_types data.tocsr(),
missing,
threads,
feature_names,
feature_types,
data_split_mode,
) )
if _is_np_array_like(data): if _is_np_array_like(data):
return _from_numpy_array( return _from_numpy_array(
@ -1043,9 +1071,13 @@ def dispatch_data_backend(
if _is_uri(data): if _is_uri(data):
return _from_uri(data, missing, feature_names, feature_types, data_split_mode) return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
if _is_list(data): if _is_list(data):
return _from_list(data, missing, threads, feature_names, feature_types) return _from_list(
data, missing, threads, feature_names, feature_types, data_split_mode
)
if _is_tuple(data): if _is_tuple(data):
return _from_tuple(data, missing, threads, feature_names, feature_types) return _from_tuple(
data, missing, threads, feature_names, feature_types, data_split_mode
)
if _is_arrow(data): if _is_arrow(data):
data = _arrow_transform(data) data = _arrow_transform(data)
if _is_pandas_series(data): if _is_pandas_series(data):
@ -1054,7 +1086,13 @@ def dispatch_data_backend(
data = pd.DataFrame(data) data = pd.DataFrame(data)
if _is_pandas_df(data): if _is_pandas_df(data):
return _from_pandas_df( return _from_pandas_df(
data, enable_categorical, missing, threads, feature_names, feature_types data,
enable_categorical,
missing,
threads,
feature_names,
feature_types,
data_split_mode,
) )
if _is_cudf_df(data) or _is_cudf_ser(data): if _is_cudf_df(data) or _is_cudf_ser(data):
return _from_cudf_df( return _from_cudf_df(

View File

@ -10,6 +10,7 @@ import os
import platform import platform
import socket import socket
import sys import sys
import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from io import StringIO from io import StringIO
@ -34,6 +35,7 @@ import pytest
from scipy import sparse from scipy import sparse
import xgboost as xgb import xgboost as xgb
from xgboost import RabitTracker
from xgboost.core import ArrayLike from xgboost.core import ArrayLike
from xgboost.sklearn import SklObjective from xgboost.sklearn import SklObjective
from xgboost.testing.data import ( from xgboost.testing.data import (
@ -938,3 +940,22 @@ def load_agaricus(path: str) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
def project_root(path: str) -> str: def project_root(path: str) -> str:
return normpath(os.path.join(demo_dir(path), os.path.pardir)) return normpath(os.path.join(demo_dir(path), os.path.pardir))
def run_with_rabit(world_size: int, test_fn: Callable) -> None:
    """Run ``test_fn`` in ``world_size`` worker threads connected to a Rabit tracker.

    Parameters
    ----------
    world_size :
        Number of worker threads to start; also passed to the tracker as the number
        of workers.
    test_fn :
        Callable without arguments. It is invoked once per worker inside a
        ``CommunicatorContext`` created from the tracker's worker environment.

    Raises
    ------
    Exception
        The first exception raised by ``test_fn`` (or by entering the communicator
        context) in any worker is re-raised in the calling thread after all workers
        have been joined. In that case the tracker is not joined.

    """
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
    tracker.start(world_size)

    # An exception raised inside a thread is not propagated to the thread that joins
    # it, so a failing assertion in ``test_fn`` would otherwise go unnoticed by the
    # test runner.  Collect worker failures here and re-raise them below.
    failures: list = []

    def run_worker(rabit_env: Dict[str, Union[str, int]]) -> None:
        try:
            with xgb.collective.CommunicatorContext(**rabit_env):
                test_fn()
        except Exception as exc:  # pylint: disable=broad-except
            failures.append(exc)

    workers = []
    for _ in range(world_size):
        worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),))
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()

    if failures:
        # Surface the original exception (e.g. AssertionError) to the caller.
        raise failures[0]

    tracker.join()

View File

@ -445,8 +445,11 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
auto config = Json::Load(StringView{c_json_config}); auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config); float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0); auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
xgboost_CHECK_C_ARG_PTR(out); xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads)); *out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
API_END(); API_END();
} }
@ -481,8 +484,11 @@ XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char
auto config = Json::Load(StringView{c_json_config}); auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config); float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0)); auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
xgboost_CHECK_C_ARG_PTR(out); xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads)); *out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
API_END(); API_END();
} }
@ -722,6 +728,15 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulon
API_END(); API_END();
} }
// C API entry point: writes the data split mode stored in the DMatrix's meta info
// (`Info().data_split_mode`) to `out` as an unsigned integer.
XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(out);
// The enum is exposed to C callers as its underlying numeric value.
*out = static_cast<xgboost::bst_ulong>(p_m->Info().data_split_mode);
API_END();
}
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config, XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
xgboost::bst_ulong *out_indptr, unsigned *out_indices, xgboost::bst_ulong *out_indptr, unsigned *out_indices,
float *out_data) { float *out_data) {

View File

@ -61,6 +61,7 @@ class RabitCommunicator : public Communicator {
auto const total_size = per_rank * GetWorldSize(); auto const total_size = per_rank * GetWorldSize();
auto const index = per_rank * GetRank(); auto const index = per_rank * GetRank();
std::string result(total_size, '\0'); std::string result(total_size, '\0');
result.replace(index, per_rank, input);
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank); rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
return result; return result;
} }
@ -71,7 +72,8 @@ class RabitCommunicator : public Communicator {
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul); auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
auto const begin_index = auto const begin_index =
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul); std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
auto const size_prev_slice = GetRank() == 0 ? 0 : all_sizes[GetRank() - 1]; auto const size_prev_slice =
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
std::string result(total_size, '\0'); std::string result(total_size, '\0');
result.replace(begin_index, size_node_slice, input); result.replace(begin_index, size_node_slice, input);

View File

@ -635,22 +635,39 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
} }
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) { void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
if (size != 0 && this->num_col_ != 0) { if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) {
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns."; CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
CHECK(info); CHECK(info);
} }
if (!std::strcmp(key, "feature_type")) { if (!std::strcmp(key, "feature_type")) {
feature_type_names.clear(); feature_type_names.clear();
auto& h_feature_types = feature_types.HostVector();
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
auto elem = info[i]; auto elem = info[i];
feature_type_names.emplace_back(elem); feature_type_names.emplace_back(elem);
} }
if (IsColumnSplit()) {
feature_type_names = collective::AllgatherStrings(feature_type_names);
CHECK_EQ(feature_type_names.size(), num_col_)
<< "Length of " << key << " must be equal to number of columns.";
}
auto& h_feature_types = feature_types.HostVector();
LoadFeatureType(feature_type_names, &h_feature_types); LoadFeatureType(feature_type_names, &h_feature_types);
} else if (!std::strcmp(key, "feature_name")) { } else if (!std::strcmp(key, "feature_name")) {
feature_names.clear(); if (IsColumnSplit()) {
for (size_t i = 0; i < size; ++i) { std::vector<std::string> local_feature_names{};
feature_names.emplace_back(info[i]); auto const rank = collective::GetRank();
for (std::size_t i = 0; i < size; ++i) {
auto elem = std::to_string(rank) + "." + info[i];
local_feature_names.emplace_back(elem);
}
feature_names = collective::AllgatherStrings(local_feature_names);
CHECK_EQ(feature_names.size(), num_col_)
<< "Length of " << key << " must be equal to number of columns.";
} else {
feature_names.clear();
for (size_t i = 0; i < size; ++i) {
feature_names.emplace_back(info[i]);
}
} }
} else { } else {
LOG(FATAL) << "Unknown feature info name: " << key; LOG(FATAL) << "Unknown feature info name: " << key;

View File

@ -75,7 +75,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
} }
void SimpleDMatrix::ReindexFeatures(Context const* ctx) { void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
if (info_.IsColumnSplit()) { if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) {
auto const cols = collective::Allgather(info_.num_col_); auto const cols = collective::Allgather(info_.num_col_);
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul); auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
if (offset == 0) { if (offset == 0) {

View File

@ -108,6 +108,7 @@ TEST(CAPI, XGDMatrixCreateFromCSR) {
Json::Dump(data_arr, &sdata); Json::Dump(data_arr, &sdata);
Json config{Object{}}; Json config{Object{}};
config["missing"] = Number{std::numeric_limits<float>::quiet_NaN()}; config["missing"] = Number{std::numeric_limits<float>::quiet_NaN()};
config["data_split_mode"] = Integer{static_cast<int64_t>(DataSplitMode::kCol)};
Json::Dump(config, &sconfig); Json::Dump(config, &sconfig);
DMatrixHandle handle; DMatrixHandle handle;
@ -120,6 +121,8 @@ TEST(CAPI, XGDMatrixCreateFromCSR) {
ASSERT_EQ(n, 3); ASSERT_EQ(n, 3);
ASSERT_EQ(XGDMatrixNumNonMissing(handle, &n), 0); ASSERT_EQ(XGDMatrixNumNonMissing(handle, &n), 0);
ASSERT_EQ(n, 3); ASSERT_EQ(n, 3);
ASSERT_EQ(XGDMatrixDataSplitMode(handle, &n), 0);
ASSERT_EQ(n, static_cast<int64_t>(DataSplitMode::kCol));
std::shared_ptr<xgboost::DMatrix> *pp_fmat = std::shared_ptr<xgboost::DMatrix> *pp_fmat =
static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle); static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);

View File

@ -74,6 +74,49 @@ TEST(MetaInfo, GetSetFeature) {
// Other conditions are tested in `SaveLoadBinary`. // Other conditions are tested in `SaveLoadBinary`.
} }
namespace {
// Per-worker body of the column-split feature info test.
//
// Each worker supplies kCols local feature types/names to a MetaInfo in column-split
// mode and checks that:
//   - the call fails while num_col_ only counts the local columns, and
//   - once num_col_ is kCols * world_size the call succeeds and the stored feature info
//     has kCols entries per worker.
//
// NOTE: the expected vectors below are written out for exactly 3 workers, matching the
// world size used by TEST(MetaInfo, GetSetFeatureColumnSplit).
void VerifyGetSetFeatureColumnSplit() {
xgboost::MetaInfo info;
info.data_split_mode = DataSplitMode::kCol;
auto const world_size = collective::GetWorldSize();
auto constexpr kCols{2};
// Local feature types: "float" is expected to load as numerical, "c" as categorical
// (see expected_types below).
std::vector<std::string> types{u8"float", u8"c"};
std::vector<char const *> c_types(kCols);
std::transform(types.cbegin(), types.cend(), c_types.begin(),
[](auto const &str) { return str.c_str(); });
// num_col_ equal to the local column count is expected to be rejected.
info.num_col_ = kCols;
EXPECT_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()), dmlc::Error);
// num_col_ equal to the total column count across all workers is accepted.
info.num_col_ = kCols * world_size;
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()));
// The local types repeated once per worker (3 workers).
std::vector<std::string> expected_type_names{u8"float", u8"c", u8"float",
u8"c", u8"float", u8"c"};
EXPECT_EQ(info.feature_type_names, expected_type_names);
std::vector<xgboost::FeatureType> expected_types{
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical,
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical,
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical};
EXPECT_EQ(info.feature_types.HostVector(), expected_types);
// Same sequence of checks for feature names.
std::vector<std::string> names{u8"feature0", u8"feature1"};
std::vector<char const *> c_names(kCols);
std::transform(names.cbegin(), names.cend(), c_names.begin(),
[](auto const &str) { return str.c_str(); });
info.num_col_ = kCols;
EXPECT_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()), dmlc::Error);
info.num_col_ = kCols * world_size;
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()));
// Stored names are expected to carry a "<rank>." prefix, in rank order (3 workers).
std::vector<std::string> expected_names{u8"0.feature0", u8"0.feature1", u8"1.feature0",
u8"1.feature1", u8"2.feature0", u8"2.feature1"};
EXPECT_EQ(info.feature_names, expected_names);
}
} // anonymous namespace
// Runs VerifyGetSetFeatureColumnSplit through RunWithInMemoryCommunicator with a world
// size of 3; the expected values inside the verification function assume this size.
TEST(MetaInfo, GetSetFeatureColumnSplit) {
auto constexpr kWorldSize{3};
RunWithInMemoryCommunicator(kWorldSize, VerifyGetSetFeatureColumnSplit);
}
TEST(MetaInfo, SaveLoadBinary) { TEST(MetaInfo, SaveLoadBinary) {
xgboost::MetaInfo info; xgboost::MetaInfo info;
xgboost::Context ctx; xgboost::Context ctx;

View File

@ -1,4 +1,5 @@
import os import os
import sys
import tempfile import tempfile
import numpy as np import numpy as np
@ -9,6 +10,7 @@ from scipy.sparse import csr_matrix, rand
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
from xgboost.core import DataSplitMode
from xgboost.testing.data import np_dtypes from xgboost.testing.data import np_dtypes
rng = np.random.RandomState(1) rng = np.random.RandomState(1)
@ -467,3 +469,97 @@ class TestDMatrix:
m0 = xgb.DMatrix(orig) m0 = xgb.DMatrix(orig)
m1 = xgb.DMatrix(x) m1 = xgb.DMatrix(x)
assert tm.predictor_equal(m0, m1) assert tm.predictor_equal(m0, m1)
class TestDMatrixColumnSplit:
    """Construct column-split DMatrix objects from in-memory inputs.

    Each test runs its check through ``tm.run_with_rabit`` with a world size of 3 and
    expects the DMatrix to report the local column count multiplied by the world size.
    """

    def test_numpy(self):
        def verify_numpy():
            features = np.random.randn(5, 5)
            dmat = xgb.DMatrix(features, data_split_mode=DataSplitMode.COL)
            assert dmat.num_row() == 5
            assert dmat.num_col() == 5 * xgb.collective.get_world_size()
            # No feature info was supplied, so none is expected back.
            assert dmat.feature_names is None
            assert dmat.feature_types is None

        tm.run_with_rabit(world_size=3, test_fn=verify_numpy)

    def test_numpy_feature_names(self):
        def verify_numpy_feature_names():
            n_workers = xgb.collective.get_world_size()
            features = np.random.randn(5, 5)
            names = [f"feature{idx}" for idx in range(5)]
            types = ["float" for _ in range(5)]
            dmat = xgb.DMatrix(
                features,
                feature_names=names,
                feature_types=types,
                data_split_mode=DataSplitMode.COL,
            )
            assert dmat.num_row() == 5
            assert dmat.num_col() == 5 * n_workers
            # Feature info is expected to cover the columns of every worker.
            assert len(dmat.feature_names) == 5 * n_workers
            assert len(dmat.feature_types) == 5 * n_workers

        tm.run_with_rabit(world_size=3, test_fn=verify_numpy_feature_names)

    def test_csr(self):
        def verify_csr():
            row_ptr = np.array([0, 2, 3, 6])
            col_idx = np.array([0, 2, 2, 0, 1, 2])
            values = np.array([1, 2, 3, 4, 5, 6])
            csr = scipy.sparse.csr_matrix((values, col_idx, row_ptr), shape=(3, 3))
            dmat = xgb.DMatrix(csr, data_split_mode=DataSplitMode.COL)
            assert dmat.num_row() == 3
            assert dmat.num_col() == 3 * xgb.collective.get_world_size()

        tm.run_with_rabit(world_size=3, test_fn=verify_csr)

    def test_csc(self):
        def verify_csc():
            rows = np.array([0, 2, 2, 0, 1, 2])
            cols = np.array([0, 0, 1, 2, 2, 2])
            values = np.array([1, 2, 3, 4, 5, 6])
            csc = scipy.sparse.csc_matrix((values, (rows, cols)), shape=(3, 3))
            dmat = xgb.DMatrix(csc, data_split_mode=DataSplitMode.COL)
            assert dmat.num_row() == 3
            assert dmat.num_col() == 3 * xgb.collective.get_world_size()

        tm.run_with_rabit(world_size=3, test_fn=verify_csc)

    def test_coo(self):
        def verify_coo():
            rows = np.array([0, 2, 2, 0, 1, 2])
            cols = np.array([0, 0, 1, 2, 2, 2])
            values = np.array([1, 2, 3, 4, 5, 6])
            coo = scipy.sparse.coo_matrix((values, (rows, cols)), shape=(3, 3))
            dmat = xgb.DMatrix(coo, data_split_mode=DataSplitMode.COL)
            assert dmat.num_row() == 3
            assert dmat.num_col() == 3 * xgb.collective.get_world_size()

        tm.run_with_rabit(world_size=3, test_fn=verify_coo)

    def test_list(self):
        def verify_list():
            # 5x5 nested list holding the integers 1..25 in row-major order.
            table = [list(range(5 * r + 1, 5 * r + 6)) for r in range(5)]
            dmat = xgb.DMatrix(table, data_split_mode=DataSplitMode.COL)
            assert dmat.num_row() == 5
            assert dmat.num_col() == 5 * xgb.collective.get_world_size()

        tm.run_with_rabit(world_size=3, test_fn=verify_list)

    def test_tuple(self):
        def verify_tuple():
            # 5x5 nested tuple holding the integers 1..25 in row-major order.
            table = tuple(tuple(range(5 * r + 1, 5 * r + 6)) for r in range(5))
            dmat = xgb.DMatrix(table, data_split_mode=DataSplitMode.COL)
            assert dmat.num_row() == 5
            assert dmat.num_col() == 5 * xgb.collective.get_world_size()

        tm.run_with_rabit(world_size=3, test_fn=verify_tuple)

View File

@ -1,4 +1,5 @@
import os import os
import sys
import unittest import unittest
import numpy as np import numpy as np
@ -6,6 +7,7 @@ import pytest
import xgboost as xgb import xgboost as xgb
from xgboost import testing as tm from xgboost import testing as tm
from xgboost.core import DataSplitMode
try: try:
import pandas as pd import pandas as pd
@ -97,3 +99,17 @@ class TestArrowTable:
y_np_low = dtrain.get_float_info("label_lower_bound") y_np_low = dtrain.get_float_info("label_lower_bound")
np.testing.assert_equal(y_np_up, y_upper_bound.to_pandas().values) np.testing.assert_equal(y_np_up, y_upper_bound.to_pandas().values)
np.testing.assert_equal(y_np_low, y_lower_bound.to_pandas().values) np.testing.assert_equal(y_np_low, y_lower_bound.to_pandas().values)
class TestArrowTableColumnSplit:
    """Construct a column-split DMatrix from a pyarrow table."""

    def test_arrow_table(self):
        def verify_arrow_table():
            rows = [[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]]
            frame = pd.DataFrame(rows, columns=list("abcd"))
            arrow_table = pa.Table.from_pandas(frame)
            dmat = xgb.DMatrix(arrow_table, data_split_mode=DataSplitMode.COL)
            assert dmat.num_row() == 2
            # Four local columns per worker.
            assert dmat.num_col() == 4 * xgb.collective.get_world_size()

        tm.run_with_rabit(world_size=3, test_fn=verify_arrow_table)