diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index ce2d8bf43..d28b5098b 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -144,9 +144,7 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle * See :doc:`/tutorials/input_format` for more info. * \endverbatim * - silent (optional): Whether to print message during loading. Default to true. - * - data_split_mode (optional): Whether to split by row or column. In distributed mode, the - * file is split accordingly; otherwise this is only an indicator on how the file was split - * beforehand. Default to row. + * - data_split_mode (optional): Whether the file was split by row or column beforehand for distributed computing. Default to row. * \param out a loaded data matrix * \return 0 when success, -1 when failure happens */ @@ -174,6 +172,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t *indptr, const unsigned *indic * \param config JSON encoded configuration. Required values are: * - missing: Which value to represent missing value. * - nthread (optional): Number of threads used for initializing DMatrix. + * - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row. * \param out created dmatrix * \return 0 when success, -1 when failure happens */ @@ -186,6 +185,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char * \param config JSON encoded configuration. Required values are: * - missing: Which value to represent missing value. * - nthread (optional): Number of threads used for initializing DMatrix. + * - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row. * \param out created dmatrix * \return 0 when success, -1 when failure happens */ @@ -200,6 +200,7 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data, char const *config, DMatr * \param config JSON encoded configuration. Supported values are: * - missing: Which value to represent missing value. * - nthread (optional): Number of threads used for initializing DMatrix. + * - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row. * \param out created dmatrix * \return 0 when success, -1 when failure happens */ @@ -266,6 +267,7 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, * \param config JSON encoded configuration. Required values are: * - missing: Which value to represent missing value. * - nthread (optional): Number of threads used for initializing DMatrix. + * - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row. * \param out created dmatrix * \return 0 when success, -1 when failure happens */ @@ -278,6 +280,7 @@ XGB_DLL int XGDMatrixCreateFromCudaColumnar(char const *data, char const *config * \param config JSON encoded configuration. Required values are: * - missing: Which value to represent missing value. * - nthread (optional): Number of threads used for initializing DMatrix. + * - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row. * \param out created dmatrix * \return 0 when success, -1 when failure happens */ @@ -790,6 +793,16 @@ XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle, bst_ulong *out); */ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out); +/*! + * \brief Get the data split mode from DMatrix. + * + * \param handle the handle to the DMatrix + * \param out The output of the data split mode + * + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out); + /** * \brief Get the predictors from DMatrix as CSR matrix for testing. If this is a * quantized DMatrix, quantized values are returned instead. diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 91c6bbd85..648851b31 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -303,14 +303,14 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None: def _validate_feature_info( - feature_info: Sequence[str], n_features: int, name: str + feature_info: Sequence[str], n_features: int, is_column_split: bool, name: str ) -> List[str]: if isinstance(feature_info, str) or not isinstance(feature_info, Sequence): raise TypeError( f"Expecting a sequence of strings for {name}, got: {type(feature_info)}" ) feature_info = list(feature_info) - if len(feature_info) != n_features and n_features != 0: + if len(feature_info) != n_features and n_features != 0 and not is_column_split: msg = ( f"{name} must have the same length as the number of data columns, ", f"expected {n_features}, got {len(feature_info)}", @@ -1231,6 +1231,16 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m _check_call(_LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(ret))) return ret.value + def data_split_mode(self) -> DataSplitMode: + """Get the data split mode of the DMatrix. + + .. versionadded:: 2.1.0 + + """ + ret = c_bst_ulong() + _check_call(_LIB.XGDMatrixDataSplitMode(self.handle, ctypes.byref(ret))) + return DataSplitMode(ret.value) + def slice( self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False ) -> "DMatrix": @@ -1298,7 +1308,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m # validate feature name feature_names = _validate_feature_info( - feature_names, self.num_col(), "feature names" + feature_names, + self.num_col(), + self.data_split_mode() == DataSplitMode.COL, + "feature names", ) if len(feature_names) != len(set(feature_names)): values, counts = np.unique( @@ -1371,7 +1384,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m return feature_types = _validate_feature_info( - feature_types, self.num_col(), "feature types" + feature_types, + self.num_col(), + self.data_split_mode() == DataSplitMode.COL, + "feature types", ) feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types] diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index bfdb21c80..49287d817 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -107,6 +107,7 @@ def _from_scipy_csr( nthread: int, feature_names: Optional[FeatureNames], feature_types: Optional[FeatureTypes], + data_split_mode: DataSplitMode = DataSplitMode.ROW, ) -> DispatchedDataBackendReturnType: """Initialize data from a CSR matrix.""" @@ -118,7 +119,11 @@ def _from_scipy_csr( _array_interface(data.indices), _array_interface(data.data), c_bst_ulong(data.shape[1]), - make_jcargs(missing=float(missing), nthread=int(nthread)), + make_jcargs( + missing=float(missing), + nthread=int(nthread), + data_split_mode=int(data_split_mode), + ), ctypes.byref(handle), ) ) @@ -139,6 +144,7 @@ def _from_scipy_csc( nthread: int, feature_names: Optional[FeatureNames], feature_types: Optional[FeatureTypes], + data_split_mode: DataSplitMode = DataSplitMode.ROW, ) -> DispatchedDataBackendReturnType: """Initialize data from a CSC matrix.""" handle = ctypes.c_void_p() @@ -149,7 +155,11 @@ def _from_scipy_csc( _array_interface(data.indices), _array_interface(data.data), c_bst_ulong(data.shape[0]), - make_jcargs(missing=float(missing), nthread=int(nthread)), + make_jcargs( + missing=float(missing), + nthread=int(nthread), + data_split_mode=int(data_split_mode), + ), ctypes.byref(handle), ) ) @@ -518,11 +528,14 @@ def _from_pandas_df( nthread: int, feature_names: Optional[FeatureNames], feature_types: Optional[FeatureTypes], + data_split_mode: DataSplitMode = DataSplitMode.ROW, ) -> DispatchedDataBackendReturnType: data, feature_names, feature_types = _transform_pandas_df( data, enable_categorical, feature_names, feature_types ) - return _from_numpy_array(data, missing, nthread, feature_names, feature_types) + return _from_numpy_array( + data, missing, nthread, feature_names, feature_types, data_split_mode + ) def _is_pandas_series(data: DataType) -> bool: @@ -970,10 +983,13 @@ def _from_list( n_threads: int, feature_names: Optional[FeatureNames], feature_types: Optional[FeatureTypes], + data_split_mode: DataSplitMode = DataSplitMode.ROW, ) -> DispatchedDataBackendReturnType: array = np.array(data) _check_data_shape(data) - return _from_numpy_array(array, missing, n_threads, feature_names, feature_types) + return _from_numpy_array( + array, missing, n_threads, feature_names, feature_types, data_split_mode + ) def _is_tuple(data: DataType) -> bool: @@ -986,8 +1002,11 @@ def _from_tuple( n_threads: int, feature_names: Optional[FeatureNames], feature_types: Optional[FeatureTypes], + data_split_mode: DataSplitMode = DataSplitMode.ROW, ) -> DispatchedDataBackendReturnType: - return _from_list(data, missing, n_threads, feature_names, feature_types) + return _from_list( + data, missing, n_threads, feature_names, feature_types, data_split_mode + ) def _is_iter(data: DataType) -> bool: @@ -1029,12 +1048,21 @@ def dispatch_data_backend( if not _is_cudf_ser(data) and not _is_pandas_series(data): _check_data_shape(data) if _is_scipy_csr(data): - return _from_scipy_csr(data, missing, threads, feature_names, feature_types) + return _from_scipy_csr( + data, missing, threads, feature_names, feature_types, data_split_mode + ) if _is_scipy_csc(data): - return _from_scipy_csc(data, missing, threads, feature_names, feature_types) + return _from_scipy_csc( + data, missing, threads, feature_names, feature_types, data_split_mode + ) if _is_scipy_coo(data): return _from_scipy_csr( - data.tocsr(), missing, threads, feature_names, feature_types + data.tocsr(), + missing, + threads, + feature_names, + feature_types, + data_split_mode, ) if _is_np_array_like(data): return _from_numpy_array( @@ -1043,9 +1071,13 @@ def dispatch_data_backend( if _is_uri(data): return _from_uri(data, missing, feature_names, feature_types, data_split_mode) if _is_list(data): - return _from_list(data, missing, threads, feature_names, feature_types) + return _from_list( + data, missing, threads, feature_names, feature_types, data_split_mode + ) if _is_tuple(data): - return _from_tuple(data, missing, threads, feature_names, feature_types) + return _from_tuple( + data, missing, threads, feature_names, feature_types, data_split_mode + ) if _is_arrow(data): data = _arrow_transform(data) if _is_pandas_series(data): @@ -1054,7 +1086,13 @@ def dispatch_data_backend( data = pd.DataFrame(data) if _is_pandas_df(data): return _from_pandas_df( - data, enable_categorical, missing, threads, feature_names, feature_types + data, + enable_categorical, + missing, + threads, + feature_names, + feature_types, + data_split_mode, ) if _is_cudf_df(data) or _is_cudf_ser(data): return _from_cudf_df( diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 2e0933a43..391f2bf9f 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -10,6 +10,7 @@ import os import platform import socket import sys +import threading from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from io import StringIO @@ -34,6 +35,7 @@ import pytest from scipy import sparse import xgboost as xgb +from xgboost import RabitTracker from xgboost.core import ArrayLike from xgboost.sklearn import SklObjective from xgboost.testing.data import ( @@ -938,3 +940,22 @@ def load_agaricus(path: str) -> Tuple[xgb.DMatrix, xgb.DMatrix]: def project_root(path: str) -> str: return normpath(os.path.join(demo_dir(path), os.path.pardir)) + + +def run_with_rabit(world_size: int, test_fn: Callable) -> None: + tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size) + tracker.start(world_size) + + def run_worker(rabit_env: Dict[str, Union[str, int]]) -> None: + with xgb.collective.CommunicatorContext(**rabit_env): + test_fn() + + workers = [] + for _ in range(world_size): + worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),)) + workers.append(worker) + worker.start() + for worker in workers: + worker.join() + + tracker.join() diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 4fb6d90ff..8975bfb2e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -445,8 +445,11 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char auto config = Json::Load(StringView{c_json_config}); float missing = GetMissing(config); auto n_threads = OptionalArg(config, "nthread", 0); + auto data_split_mode = + static_cast(OptionalArg(config, "data_split_mode", 0)); xgboost_CHECK_C_ARG_PTR(out); - *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); + *out = new std::shared_ptr( + DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode)); API_END(); } @@ -481,8 +484,11 @@ XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char auto config = Json::Load(StringView{c_json_config}); float missing = GetMissing(config); auto n_threads = OptionalArg(config, "nthread", common::OmpGetNumThreads(0)); + auto data_split_mode = + static_cast(OptionalArg(config, "data_split_mode", 0)); xgboost_CHECK_C_ARG_PTR(out); - *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, n_threads)); + *out = new std::shared_ptr( + DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode)); API_END(); } @@ -722,6 +728,15 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulon API_END(); } +XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out) { + API_BEGIN(); + CHECK_HANDLE(); + auto p_m = CastDMatrixHandle(handle); + xgboost_CHECK_C_ARG_PTR(out); + *out = static_cast(p_m->Info().data_split_mode); + API_END(); +} + XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config, xgboost::bst_ulong *out_indptr, unsigned *out_indices, float *out_data) { diff --git a/src/collective/rabit_communicator.h b/src/collective/rabit_communicator.h index 59a4bbbd8..452e9ad9c 100644 --- a/src/collective/rabit_communicator.h +++ b/src/collective/rabit_communicator.h @@ -61,6 +61,7 @@ class RabitCommunicator : public Communicator { auto const total_size = per_rank * GetWorldSize(); auto const index = per_rank * GetRank(); std::string result(total_size, '\0'); + result.replace(index, per_rank, input); rabit::Allgather(result.data(), total_size, index, per_rank, per_rank); return result; } @@ -71,7 +72,8 @@ class RabitCommunicator : public Communicator { auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul); auto const begin_index = std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul); - auto const size_prev_slice = GetRank() == 0 ? 0 : all_sizes[GetRank() - 1]; + auto const size_prev_slice = + GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1]; std::string result(total_size, '\0'); result.replace(begin_index, size_node_slice, input); diff --git a/src/data/data.cc b/src/data/data.cc index 3c190a90b..7e70fff3f 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -635,22 +635,39 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype, } void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) { - if (size != 0 && this->num_col_ != 0) { + if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) { CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns."; CHECK(info); } if (!std::strcmp(key, "feature_type")) { feature_type_names.clear(); - auto& h_feature_types = feature_types.HostVector(); for (size_t i = 0; i < size; ++i) { auto elem = info[i]; feature_type_names.emplace_back(elem); } + if (IsColumnSplit()) { + feature_type_names = collective::AllgatherStrings(feature_type_names); + CHECK_EQ(feature_type_names.size(), num_col_) + << "Length of " << key << " must be equal to number of columns."; + } + auto& h_feature_types = feature_types.HostVector(); LoadFeatureType(feature_type_names, &h_feature_types); } else if (!std::strcmp(key, "feature_name")) { - feature_names.clear(); - for (size_t i = 0; i < size; ++i) { - feature_names.emplace_back(info[i]); + if (IsColumnSplit()) { + std::vector local_feature_names{}; + auto const rank = collective::GetRank(); + for (std::size_t i = 0; i < size; ++i) { + auto elem = std::to_string(rank) + "." + info[i]; + local_feature_names.emplace_back(elem); + } + feature_names = collective::AllgatherStrings(local_feature_names); + CHECK_EQ(feature_names.size(), num_col_) + << "Length of " << key << " must be equal to number of columns."; + } else { + feature_names.clear(); + for (size_t i = 0; i < size; ++i) { + feature_names.emplace_back(info[i]); + } } } else { LOG(FATAL) << "Unknown feature info name: " << key; diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 48e764986..3814d74d2 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -75,7 +75,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) { } void SimpleDMatrix::ReindexFeatures(Context const* ctx) { - if (info_.IsColumnSplit()) { + if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) { auto const cols = collective::Allgather(info_.num_col_); auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul); if (offset == 0) { diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index 7fcab199e..4491dee92 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -108,6 +108,7 @@ TEST(CAPI, XGDMatrixCreateFromCSR) { Json::Dump(data_arr, &sdata); Json config{Object{}}; config["missing"] = Number{std::numeric_limits::quiet_NaN()}; + config["data_split_mode"] = Integer{static_cast(DataSplitMode::kCol)}; Json::Dump(config, &sconfig); DMatrixHandle handle; @@ -120,6 +121,8 @@ TEST(CAPI, XGDMatrixCreateFromCSR) { ASSERT_EQ(n, 3); ASSERT_EQ(XGDMatrixNumNonMissing(handle, &n), 0); ASSERT_EQ(n, 3); + ASSERT_EQ(XGDMatrixDataSplitMode(handle, &n), 0); + ASSERT_EQ(n, static_cast(DataSplitMode::kCol)); std::shared_ptr *pp_fmat = static_cast *>(handle); diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 3e96d1919..67c5b39a4 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -74,6 +74,49 @@ TEST(MetaInfo, GetSetFeature) { // Other conditions are tested in `SaveLoadBinary`. } +namespace { +void VerifyGetSetFeatureColumnSplit() { + xgboost::MetaInfo info; + info.data_split_mode = DataSplitMode::kCol; + auto const world_size = collective::GetWorldSize(); + + auto constexpr kCols{2}; + std::vector types{u8"float", u8"c"}; + std::vector c_types(kCols); + std::transform(types.cbegin(), types.cend(), c_types.begin(), + [](auto const &str) { return str.c_str(); }); + info.num_col_ = kCols; + EXPECT_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()), dmlc::Error); + info.num_col_ = kCols * world_size; + EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size())); + std::vector expected_type_names{u8"float", u8"c", u8"float", + u8"c", u8"float", u8"c"}; + EXPECT_EQ(info.feature_type_names, expected_type_names); + std::vector expected_types{ + xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical, + xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical, + xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical}; + EXPECT_EQ(info.feature_types.HostVector(), expected_types); + + std::vector names{u8"feature0", u8"feature1"}; + std::vector c_names(kCols); + std::transform(names.cbegin(), names.cend(), c_names.begin(), + [](auto const &str) { return str.c_str(); }); + info.num_col_ = kCols; + EXPECT_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()), dmlc::Error); + info.num_col_ = kCols * world_size; + EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size())); + std::vector expected_names{u8"0.feature0", u8"0.feature1", u8"1.feature0", + u8"1.feature1", u8"2.feature0", u8"2.feature1"}; + EXPECT_EQ(info.feature_names, expected_names); +} +} // anonymous namespace + +TEST(MetaInfo, GetSetFeatureColumnSplit) { + auto constexpr kWorldSize{3}; + RunWithInMemoryCommunicator(kWorldSize, VerifyGetSetFeatureColumnSplit); +} + TEST(MetaInfo, SaveLoadBinary) { xgboost::MetaInfo info; xgboost::Context ctx; diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py index 73e2055b7..51bee5669 100644 --- a/tests/python/test_dmatrix.py +++ b/tests/python/test_dmatrix.py @@ -1,4 +1,5 @@ import os +import sys import tempfile import numpy as np @@ -9,6 +10,7 @@ from scipy.sparse import csr_matrix, rand import xgboost as xgb from xgboost import testing as tm +from xgboost.core import DataSplitMode from xgboost.testing.data import np_dtypes rng = np.random.RandomState(1) @@ -467,3 +469,97 @@ class TestDMatrix: m0 = xgb.DMatrix(orig) m1 = xgb.DMatrix(x) assert tm.predictor_equal(m0, m1) + + +class TestDMatrixColumnSplit: + def test_numpy(self): + def verify_numpy(): + data = np.random.randn(5, 5) + dm = xgb.DMatrix(data, data_split_mode=DataSplitMode.COL) + assert dm.num_row() == 5 + assert dm.num_col() == 5 * xgb.collective.get_world_size() + assert dm.feature_names is None + assert dm.feature_types is None + + tm.run_with_rabit(world_size=3, test_fn=verify_numpy) + + def test_numpy_feature_names(self): + def verify_numpy_feature_names(): + world_size = xgb.collective.get_world_size() + data = np.random.randn(5, 5) + feature_names = [f'feature{x}' for x in range(5)] + feature_types = ['float'] * 5 + dm = xgb.DMatrix(data, feature_names=feature_names, feature_types=feature_types, + data_split_mode=DataSplitMode.COL) + assert dm.num_row() == 5 + assert dm.num_col() == 5 * world_size + assert len(dm.feature_names) == 5 * world_size + assert len(dm.feature_types) == 5 * world_size + + tm.run_with_rabit(world_size=3, test_fn=verify_numpy_feature_names) + + def test_csr(self): + def verify_csr(): + indptr = np.array([0, 2, 3, 6]) + indices = np.array([0, 2, 2, 0, 1, 2]) + data = np.array([1, 2, 3, 4, 5, 6]) + X = scipy.sparse.csr_matrix((data, indices, indptr), shape=(3, 3)) + dtrain = xgb.DMatrix(X, data_split_mode=DataSplitMode.COL) + assert dtrain.num_row() == 3 + assert dtrain.num_col() == 3 * xgb.collective.get_world_size() + + tm.run_with_rabit(world_size=3, test_fn=verify_csr) + + def test_csc(self): + def verify_csc(): + row = np.array([0, 2, 2, 0, 1, 2]) + col = np.array([0, 0, 1, 2, 2, 2]) + data = np.array([1, 2, 3, 4, 5, 6]) + X = scipy.sparse.csc_matrix((data, (row, col)), shape=(3, 3)) + dtrain = xgb.DMatrix(X, data_split_mode=DataSplitMode.COL) + assert dtrain.num_row() == 3 + assert dtrain.num_col() == 3 * xgb.collective.get_world_size() + + tm.run_with_rabit(world_size=3, test_fn=verify_csc) + + def test_coo(self): + def verify_coo(): + row = np.array([0, 2, 2, 0, 1, 2]) + col = np.array([0, 0, 1, 2, 2, 2]) + data = np.array([1, 2, 3, 4, 5, 6]) + X = scipy.sparse.coo_matrix((data, (row, col)), shape=(3, 3)) + dtrain = xgb.DMatrix(X, data_split_mode=DataSplitMode.COL) + assert dtrain.num_row() == 3 + assert dtrain.num_col() == 3 * xgb.collective.get_world_size() + + tm.run_with_rabit(world_size=3, test_fn=verify_coo) + + def test_list(self): + def verify_list(): + data = [ + [1, 2, 3, 4, 5], + [6, 7, 8, 9, 10], + [11, 12, 13, 14, 15], + [16, 17, 18, 19, 20], + [21, 22, 23, 24, 25] + ] + dm = xgb.DMatrix(data, data_split_mode=DataSplitMode.COL) + assert dm.num_row() == 5 + assert dm.num_col() == 5 * xgb.collective.get_world_size() + + tm.run_with_rabit(world_size=3, test_fn=verify_list) + + def test_tuple(self): + def verify_tuple(): + data = ( + (1, 2, 3, 4, 5), + (6, 7, 8, 9, 10), + (11, 12, 13, 14, 15), + (16, 17, 18, 19, 20), + (21, 22, 23, 24, 25) + ) + dm = xgb.DMatrix(data, data_split_mode=DataSplitMode.COL) + assert dm.num_row() == 5 + assert dm.num_col() == 5 * xgb.collective.get_world_size() + + tm.run_with_rabit(world_size=3, test_fn=verify_tuple) diff --git a/tests/python/test_with_arrow.py b/tests/python/test_with_arrow.py index 4673a688e..fdc4c7dbe 100644 --- a/tests/python/test_with_arrow.py +++ b/tests/python/test_with_arrow.py @@ -1,4 +1,5 @@ import os +import sys import unittest import numpy as np @@ -6,6 +7,7 @@ import pytest import xgboost as xgb from xgboost import testing as tm +from xgboost.core import DataSplitMode try: import pandas as pd @@ -97,3 +99,17 @@ class TestArrowTable: y_np_low = dtrain.get_float_info("label_lower_bound") np.testing.assert_equal(y_np_up, y_upper_bound.to_pandas().values) np.testing.assert_equal(y_np_low, y_lower_bound.to_pandas().values) + + +class TestArrowTableColumnSplit: + def test_arrow_table(self): + def verify_arrow_table(): + df = pd.DataFrame( + [[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"] + ) + table = pa.Table.from_pandas(df) + dm = xgb.DMatrix(table, data_split_mode=DataSplitMode.COL) + assert dm.num_row() == 2 + assert dm.num_col() == 4 * xgb.collective.get_world_size() + + tm.run_with_rabit(world_size=3, test_fn=verify_arrow_table)