Support column-wise data split with in-memory inputs (#9628)

---------

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
Rong Ou authored 2023-10-16 21:16:39 -07:00, committed by GitHub
parent 4d1607eefd
commit da6803b75b
12 changed files with 307 additions and 27 deletions

View File

@ -144,9 +144,7 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
* See :doc:`/tutorials/input_format` for more info.
* \endverbatim
* - silent (optional): Whether to print message during loading. Default to true.
* - data_split_mode (optional): Whether to split by row or column. In distributed mode, the
* file is split accordingly; otherwise this is only an indicator on how the file was split
* beforehand. Default to row.
* - data_split_mode (optional): Whether the file was split by row or column beforehand for distributed computing. Default to row.
* \param out a loaded data matrix
* \return 0 when success, -1 when failure happens
*/
@ -174,6 +172,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t *indptr, const unsigned *indic
* \param config JSON encoded configuration. Required values are:
* - missing: Which value to represent missing value.
* - nthread (optional): Number of threads used for initializing DMatrix.
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
* \param out created dmatrix
* \return 0 when success, -1 when failure happens
*/
@ -186,6 +185,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
* \param config JSON encoded configuration. Required values are:
* - missing: Which value to represent missing value.
* - nthread (optional): Number of threads used for initializing DMatrix.
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
* \param out created dmatrix
* \return 0 when success, -1 when failure happens
*/
@ -200,6 +200,7 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data, char const *config, DMatr
* \param config JSON encoded configuration. Supported values are:
* - missing: Which value to represent missing value.
* - nthread (optional): Number of threads used for initializing DMatrix.
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
* \param out created dmatrix
* \return 0 when success, -1 when failure happens
*/
@ -266,6 +267,7 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data,
* \param config JSON encoded configuration. Required values are:
* - missing: Which value to represent missing value.
* - nthread (optional): Number of threads used for initializing DMatrix.
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
* \param out created dmatrix
* \return 0 when success, -1 when failure happens
*/
@ -278,6 +280,7 @@ XGB_DLL int XGDMatrixCreateFromCudaColumnar(char const *data, char const *config
* \param config JSON encoded configuration. Required values are:
* - missing: Which value to represent missing value.
* - nthread (optional): Number of threads used for initializing DMatrix.
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
* \param out created dmatrix
* \return 0 when success, -1 when failure happens
*/
@ -790,6 +793,16 @@ XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle, bst_ulong *out);
*/
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
/*!
* \brief Get the data split mode from DMatrix.
*
* \param handle the handle to the DMatrix
* \param out The output of the data split mode
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out);
/**
* \brief Get the predictors from DMatrix as CSR matrix for testing. If this is a
* quantized DMatrix, quantized values are returned instead.
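
For reference, a minimal sketch of the JSON configuration these creators accept, with the new optional key. The enum values (0 = row, 1 = column) follow DataSplitMode and are an assumption of this sketch rather than part of the diff.

    import json

    # Sketch of the `config` string passed to the XGDMatrixCreateFrom* functions
    # above.  `data_split_mode` is optional and defaults to row (0); 1 marks data
    # that was split by column beforehand.
    config = json.dumps(
        {
            "missing": float("nan"),  # value representing missing entries
            "nthread": 4,             # optional: threads for DMatrix construction
            "data_split_mode": 1,     # optional: 0 = row (default), 1 = column
        }
    )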

View File

@ -303,14 +303,14 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
def _validate_feature_info(
feature_info: Sequence[str], n_features: int, name: str
feature_info: Sequence[str], n_features: int, is_column_split: bool, name: str
) -> List[str]:
if isinstance(feature_info, str) or not isinstance(feature_info, Sequence):
raise TypeError(
f"Expecting a sequence of strings for {name}, got: {type(feature_info)}"
)
feature_info = list(feature_info)
if len(feature_info) != n_features and n_features != 0:
if len(feature_info) != n_features and n_features != 0 and not is_column_split:
msg = (
f"{name} must have the same length as the number of data columns, ",
f"expected {n_features}, got {len(feature_info)}",
@ -1231,6 +1231,16 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
_check_call(_LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(ret)))
return ret.value
def data_split_mode(self) -> DataSplitMode:
"""Get the data split mode of the DMatrix.
.. versionadded:: 2.1.0
"""
ret = c_bst_ulong()
_check_call(_LIB.XGDMatrixDataSplitMode(self.handle, ctypes.byref(ret)))
return DataSplitMode(ret.value)
def slice(
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
) -> "DMatrix":
@ -1298,7 +1308,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
# validate feature name
feature_names = _validate_feature_info(
feature_names, self.num_col(), "feature names"
feature_names,
self.num_col(),
self.data_split_mode() == DataSplitMode.COL,
"feature names",
)
if len(feature_names) != len(set(feature_names)):
values, counts = np.unique(
@ -1371,7 +1384,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
return
feature_types = _validate_feature_info(
feature_types, self.num_col(), "feature types"
feature_types,
self.num_col(),
self.data_split_mode() == DataSplitMode.COL,
"feature types",
)
feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types]
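
The new accessor makes the split mode queryable from Python. A minimal sketch follows; it assumes that outside a collective worker (world size 1) a column-split matrix simply keeps its local columns while still reporting the mode it was created with.

    import numpy as np
    import xgboost as xgb
    from xgboost.core import DataSplitMode

    # Column-split DMatrix built from a local 5x5 block.
    dm_col = xgb.DMatrix(np.random.randn(5, 5), data_split_mode=DataSplitMode.COL)
    assert dm_col.data_split_mode() == DataSplitMode.COL

    # The default remains row split.
    dm_row = xgb.DMatrix(np.random.randn(5, 5))
    assert dm_row.data_split_mode() == DataSplitMode.ROW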

View File

@ -107,6 +107,7 @@ def _from_scipy_csr(
nthread: int,
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
"""Initialize data from a CSR matrix."""
@ -118,7 +119,11 @@ def _from_scipy_csr(
_array_interface(data.indices),
_array_interface(data.data),
c_bst_ulong(data.shape[1]),
make_jcargs(missing=float(missing), nthread=int(nthread)),
make_jcargs(
missing=float(missing),
nthread=int(nthread),
data_split_mode=int(data_split_mode),
),
ctypes.byref(handle),
)
)
@ -139,6 +144,7 @@ def _from_scipy_csc(
nthread: int,
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
"""Initialize data from a CSC matrix."""
handle = ctypes.c_void_p()
@ -149,7 +155,11 @@ def _from_scipy_csc(
_array_interface(data.indices),
_array_interface(data.data),
c_bst_ulong(data.shape[0]),
make_jcargs(missing=float(missing), nthread=int(nthread)),
make_jcargs(
missing=float(missing),
nthread=int(nthread),
data_split_mode=int(data_split_mode),
),
ctypes.byref(handle),
)
)
@ -518,11 +528,14 @@ def _from_pandas_df(
nthread: int,
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
data, feature_names, feature_types = _transform_pandas_df(
data, enable_categorical, feature_names, feature_types
)
return _from_numpy_array(data, missing, nthread, feature_names, feature_types)
return _from_numpy_array(
data, missing, nthread, feature_names, feature_types, data_split_mode
)
def _is_pandas_series(data: DataType) -> bool:
@ -970,10 +983,13 @@ def _from_list(
n_threads: int,
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
array = np.array(data)
_check_data_shape(data)
return _from_numpy_array(array, missing, n_threads, feature_names, feature_types)
return _from_numpy_array(
array, missing, n_threads, feature_names, feature_types, data_split_mode
)
def _is_tuple(data: DataType) -> bool:
@ -986,8 +1002,11 @@ def _from_tuple(
n_threads: int,
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
return _from_list(data, missing, n_threads, feature_names, feature_types)
return _from_list(
data, missing, n_threads, feature_names, feature_types, data_split_mode
)
def _is_iter(data: DataType) -> bool:
@ -1029,12 +1048,21 @@ def dispatch_data_backend(
if not _is_cudf_ser(data) and not _is_pandas_series(data):
_check_data_shape(data)
if _is_scipy_csr(data):
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
return _from_scipy_csr(
data, missing, threads, feature_names, feature_types, data_split_mode
)
if _is_scipy_csc(data):
return _from_scipy_csc(data, missing, threads, feature_names, feature_types)
return _from_scipy_csc(
data, missing, threads, feature_names, feature_types, data_split_mode
)
if _is_scipy_coo(data):
return _from_scipy_csr(
data.tocsr(), missing, threads, feature_names, feature_types
data.tocsr(),
missing,
threads,
feature_names,
feature_types,
data_split_mode,
)
if _is_np_array_like(data):
return _from_numpy_array(
@ -1043,9 +1071,13 @@ def dispatch_data_backend(
if _is_uri(data):
return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
if _is_list(data):
return _from_list(data, missing, threads, feature_names, feature_types)
return _from_list(
data, missing, threads, feature_names, feature_types, data_split_mode
)
if _is_tuple(data):
return _from_tuple(data, missing, threads, feature_names, feature_types)
return _from_tuple(
data, missing, threads, feature_names, feature_types, data_split_mode
)
if _is_arrow(data):
data = _arrow_transform(data)
if _is_pandas_series(data):
@ -1054,7 +1086,13 @@ def dispatch_data_backend(
data = pd.DataFrame(data)
if _is_pandas_df(data):
return _from_pandas_df(
data, enable_categorical, missing, threads, feature_names, feature_types
data,
enable_categorical,
missing,
threads,
feature_names,
feature_types,
data_split_mode,
)
if _is_cudf_df(data) or _is_cudf_ser(data):
return _from_cudf_df(
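
All in-memory dispatch paths now forward the flag, so the same call works for dense arrays, CSR/CSC/COO matrices, lists, tuples, and pandas frames; COO input is converted to CSR first and takes the same route. A short sketch mirroring the new tests further down:

    import numpy as np
    import scipy.sparse
    import xgboost as xgb
    from xgboost.core import DataSplitMode

    indptr = np.array([0, 2, 3, 6])
    indices = np.array([0, 2, 2, 0, 1, 2])
    values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    X = scipy.sparse.csr_matrix((values, indices, indptr), shape=(3, 3))

    # The flag is threaded through _from_scipy_csr via make_jcargs.
    dm = xgb.DMatrix(X, data_split_mode=DataSplitMode.COL)
    assert dm.data_split_mode() == DataSplitMode.COL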

View File

@ -10,6 +10,7 @@ import os
import platform
import socket
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from io import StringIO
@ -34,6 +35,7 @@ import pytest
from scipy import sparse
import xgboost as xgb
from xgboost import RabitTracker
from xgboost.core import ArrayLike
from xgboost.sklearn import SklObjective
from xgboost.testing.data import (
@ -938,3 +940,22 @@ def load_agaricus(path: str) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
def project_root(path: str) -> str:
return normpath(os.path.join(demo_dir(path), os.path.pardir))
def run_with_rabit(world_size: int, test_fn: Callable) -> None:
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
tracker.start(world_size)
def run_worker(rabit_env: Dict[str, Union[str, int]]) -> None:
with xgb.collective.CommunicatorContext(**rabit_env):
test_fn()
workers = []
for _ in range(world_size):
worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),))
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
tracker.join()
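
The helper spins up an in-process RabitTracker plus one thread per worker, each wrapped in a CommunicatorContext, and runs test_fn on every worker. A usage sketch (verify_numpy_col_split is a hypothetical test body, modeled on the new DMatrix tests below):

    import numpy as np
    import xgboost as xgb
    from xgboost import testing as tm
    from xgboost.core import DataSplitMode

    def verify_numpy_col_split() -> None:
        # Runs inside a CommunicatorContext on each of the three worker threads.
        dm = xgb.DMatrix(np.random.randn(5, 5), data_split_mode=DataSplitMode.COL)
        assert dm.num_col() == 5 * xgb.collective.get_world_size()

    tm.run_with_rabit(world_size=3, test_fn=verify_numpy_col_split)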

View File

@ -445,8 +445,11 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
API_END();
}
@ -481,8 +484,11 @@ XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char
auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
API_END();
}
@ -722,6 +728,15 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulon
API_END();
}
XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(out);
*out = static_cast<xgboost::bst_ulong>(p_m->Info().data_split_mode);
API_END();
}
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
xgboost::bst_ulong *out_indptr, unsigned *out_indices,
float *out_data) {
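
On the Python side the new entry point is a plain ctypes call; here is a sketch mirroring the binding shown earlier. It leans on private helpers from xgboost.core (_LIB, _check_call, c_bst_ulong), query_split_mode is a hypothetical wrapper, and handle is assumed to be a live DMatrixHandle such as dmatrix.handle.

    import ctypes
    from xgboost.core import _LIB, _check_call, c_bst_ulong

    def query_split_mode(handle: ctypes.c_void_p) -> int:
        # The C function writes the MetaInfo enum value (0 = row, 1 = column)
        # into an unsigned integer out-parameter.
        ret = c_bst_ulong()
        _check_call(_LIB.XGDMatrixDataSplitMode(handle, ctypes.byref(ret)))
        return ret.value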

View File

@ -61,6 +61,7 @@ class RabitCommunicator : public Communicator {
auto const total_size = per_rank * GetWorldSize();
auto const index = per_rank * GetRank();
std::string result(total_size, '\0');
result.replace(index, per_rank, input);
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
return result;
}
@ -71,7 +72,8 @@ class RabitCommunicator : public Communicator {
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
auto const begin_index =
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
auto const size_prev_slice = GetRank() == 0 ? 0 : all_sizes[GetRank() - 1];
auto const size_prev_slice =
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
std::string result(total_size, '\0');
result.replace(begin_index, size_node_slice, input);
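
The second hunk changes how the variable-length allgather computes the size of the slice preceding each rank: rank 0 now wraps around to the last rank instead of using 0. A tiny illustrative sketch of that index arithmetic (prev_slice_size is a hypothetical helper, not part of the rabit API):

    from typing import List

    def prev_slice_size(all_sizes: List[int], rank: int) -> int:
        # Ring allgather: the slice arriving "before" rank 0 comes from the last
        # rank, so its size is all_sizes[-1] rather than 0.
        return all_sizes[-1] if rank == 0 else all_sizes[rank - 1]

    assert prev_slice_size([4, 7, 5], rank=0) == 5
    assert prev_slice_size([4, 7, 5], rank=2) == 7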

View File

@ -635,22 +635,39 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
}
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
if (size != 0 && this->num_col_ != 0) {
if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) {
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
CHECK(info);
}
if (!std::strcmp(key, "feature_type")) {
feature_type_names.clear();
auto& h_feature_types = feature_types.HostVector();
for (size_t i = 0; i < size; ++i) {
auto elem = info[i];
feature_type_names.emplace_back(elem);
}
if (IsColumnSplit()) {
feature_type_names = collective::AllgatherStrings(feature_type_names);
CHECK_EQ(feature_type_names.size(), num_col_)
<< "Length of " << key << " must be equal to number of columns.";
}
auto& h_feature_types = feature_types.HostVector();
LoadFeatureType(feature_type_names, &h_feature_types);
} else if (!std::strcmp(key, "feature_name")) {
feature_names.clear();
for (size_t i = 0; i < size; ++i) {
feature_names.emplace_back(info[i]);
if (IsColumnSplit()) {
std::vector<std::string> local_feature_names{};
auto const rank = collective::GetRank();
for (std::size_t i = 0; i < size; ++i) {
auto elem = std::to_string(rank) + "." + info[i];
local_feature_names.emplace_back(elem);
}
feature_names = collective::AllgatherStrings(local_feature_names);
CHECK_EQ(feature_names.size(), num_col_)
<< "Length of " << key << " must be equal to number of columns.";
} else {
feature_names.clear();
for (size_t i = 0; i < size; ++i) {
feature_names.emplace_back(info[i]);
}
}
} else {
LOG(FATAL) << "Unknown feature info name: " << key;
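
Under column split, each worker rank-prefixes its local feature names and the prefixed lists are allgathered, so every worker ends up with the same global list. A Python sketch of the resulting naming, matching the expectations in the new MetaInfo test below (global_feature_names is a hypothetical helper and assumes every worker supplies the same local names):

    from typing import List

    def global_feature_names(local: List[str], world_size: int) -> List[str]:
        # Rank r contributes its names prefixed with "r."; the allgather then
        # concatenates the per-rank lists in rank order.
        return [f"{r}.{name}" for r in range(world_size) for name in local]

    assert global_feature_names(["feature0", "feature1"], world_size=3) == [
        "0.feature0", "0.feature1",
        "1.feature0", "1.feature1",
        "2.feature0", "2.feature1",
    ]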

View File

@ -75,7 +75,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
}
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
if (info_.IsColumnSplit()) {
if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) {
auto const cols = collective::Allgather(info_.num_col_);
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
if (offset == 0) {
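
ReindexFeatures shifts each worker's local column indices by the number of columns held by lower-ranked workers; the added world-size check makes single-process column-split construction a no-op. The offset itself is a prefix sum, sketched below (column_offset is a hypothetical helper):

    from typing import List

    def column_offset(cols_per_rank: List[int], rank: int) -> int:
        # Columns owned by ranks below `rank`; rank 0 gets offset 0 and skips
        # the reindexing pass entirely.
        return sum(cols_per_rank[:rank])

    assert [column_offset([5, 5, 5], r) for r in range(3)] == [0, 5, 10]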

View File

@ -108,6 +108,7 @@ TEST(CAPI, XGDMatrixCreateFromCSR) {
Json::Dump(data_arr, &sdata);
Json config{Object{}};
config["missing"] = Number{std::numeric_limits<float>::quiet_NaN()};
config["data_split_mode"] = Integer{static_cast<int64_t>(DataSplitMode::kCol)};
Json::Dump(config, &sconfig);
DMatrixHandle handle;
@ -120,6 +121,8 @@ TEST(CAPI, XGDMatrixCreateFromCSR) {
ASSERT_EQ(n, 3);
ASSERT_EQ(XGDMatrixNumNonMissing(handle, &n), 0);
ASSERT_EQ(n, 3);
ASSERT_EQ(XGDMatrixDataSplitMode(handle, &n), 0);
ASSERT_EQ(n, static_cast<int64_t>(DataSplitMode::kCol));
std::shared_ptr<xgboost::DMatrix> *pp_fmat =
static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);

View File

@ -74,6 +74,49 @@ TEST(MetaInfo, GetSetFeature) {
// Other conditions are tested in `SaveLoadBinary`.
}
namespace {
void VerifyGetSetFeatureColumnSplit() {
xgboost::MetaInfo info;
info.data_split_mode = DataSplitMode::kCol;
auto const world_size = collective::GetWorldSize();
auto constexpr kCols{2};
std::vector<std::string> types{u8"float", u8"c"};
std::vector<char const *> c_types(kCols);
std::transform(types.cbegin(), types.cend(), c_types.begin(),
[](auto const &str) { return str.c_str(); });
info.num_col_ = kCols;
EXPECT_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()), dmlc::Error);
info.num_col_ = kCols * world_size;
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()));
std::vector<std::string> expected_type_names{u8"float", u8"c", u8"float",
u8"c", u8"float", u8"c"};
EXPECT_EQ(info.feature_type_names, expected_type_names);
std::vector<xgboost::FeatureType> expected_types{
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical,
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical,
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical};
EXPECT_EQ(info.feature_types.HostVector(), expected_types);
std::vector<std::string> names{u8"feature0", u8"feature1"};
std::vector<char const *> c_names(kCols);
std::transform(names.cbegin(), names.cend(), c_names.begin(),
[](auto const &str) { return str.c_str(); });
info.num_col_ = kCols;
EXPECT_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()), dmlc::Error);
info.num_col_ = kCols * world_size;
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()));
std::vector<std::string> expected_names{u8"0.feature0", u8"0.feature1", u8"1.feature0",
u8"1.feature1", u8"2.feature0", u8"2.feature1"};
EXPECT_EQ(info.feature_names, expected_names);
}
} // anonymous namespace
TEST(MetaInfo, GetSetFeatureColumnSplit) {
auto constexpr kWorldSize{3};
RunWithInMemoryCommunicator(kWorldSize, VerifyGetSetFeatureColumnSplit);
}
TEST(MetaInfo, SaveLoadBinary) {
xgboost::MetaInfo info;
xgboost::Context ctx;

View File

@ -1,4 +1,5 @@
import os
import sys
import tempfile
import numpy as np
@ -9,6 +10,7 @@ from scipy.sparse import csr_matrix, rand
import xgboost as xgb
from xgboost import testing as tm
from xgboost.core import DataSplitMode
from xgboost.testing.data import np_dtypes
rng = np.random.RandomState(1)
@ -467,3 +469,97 @@ class TestDMatrix:
m0 = xgb.DMatrix(orig)
m1 = xgb.DMatrix(x)
assert tm.predictor_equal(m0, m1)
class TestDMatrixColumnSplit:
def test_numpy(self):
def verify_numpy():
data = np.random.randn(5, 5)
dm = xgb.DMatrix(data, data_split_mode=DataSplitMode.COL)
assert dm.num_row() == 5
assert dm.num_col() == 5 * xgb.collective.get_world_size()
assert dm.feature_names is None
assert dm.feature_types is None
tm.run_with_rabit(world_size=3, test_fn=verify_numpy)
def test_numpy_feature_names(self):
def verify_numpy_feature_names():
world_size = xgb.collective.get_world_size()
data = np.random.randn(5, 5)
feature_names = [f'feature{x}' for x in range(5)]
feature_types = ['float'] * 5
dm = xgb.DMatrix(data, feature_names=feature_names, feature_types=feature_types,
data_split_mode=DataSplitMode.COL)
assert dm.num_row() == 5
assert dm.num_col() == 5 * world_size
assert len(dm.feature_names) == 5 * world_size
assert len(dm.feature_types) == 5 * world_size
tm.run_with_rabit(world_size=3, test_fn=verify_numpy_feature_names)
def test_csr(self):
def verify_csr():
indptr = np.array([0, 2, 3, 6])
indices = np.array([0, 2, 2, 0, 1, 2])
data = np.array([1, 2, 3, 4, 5, 6])
X = scipy.sparse.csr_matrix((data, indices, indptr), shape=(3, 3))
dtrain = xgb.DMatrix(X, data_split_mode=DataSplitMode.COL)
assert dtrain.num_row() == 3
assert dtrain.num_col() == 3 * xgb.collective.get_world_size()
tm.run_with_rabit(world_size=3, test_fn=verify_csr)
def test_csc(self):
def verify_csc():
row = np.array([0, 2, 2, 0, 1, 2])
col = np.array([0, 0, 1, 2, 2, 2])
data = np.array([1, 2, 3, 4, 5, 6])
X = scipy.sparse.csc_matrix((data, (row, col)), shape=(3, 3))
dtrain = xgb.DMatrix(X, data_split_mode=DataSplitMode.COL)
assert dtrain.num_row() == 3
assert dtrain.num_col() == 3 * xgb.collective.get_world_size()
tm.run_with_rabit(world_size=3, test_fn=verify_csc)
def test_coo(self):
def verify_coo():
row = np.array([0, 2, 2, 0, 1, 2])
col = np.array([0, 0, 1, 2, 2, 2])
data = np.array([1, 2, 3, 4, 5, 6])
X = scipy.sparse.coo_matrix((data, (row, col)), shape=(3, 3))
dtrain = xgb.DMatrix(X, data_split_mode=DataSplitMode.COL)
assert dtrain.num_row() == 3
assert dtrain.num_col() == 3 * xgb.collective.get_world_size()
tm.run_with_rabit(world_size=3, test_fn=verify_coo)
def test_list(self):
def verify_list():
data = [
[1, 2, 3, 4, 5],
[6, 7, 8, 9, 10],
[11, 12, 13, 14, 15],
[16, 17, 18, 19, 20],
[21, 22, 23, 24, 25]
]
dm = xgb.DMatrix(data, data_split_mode=DataSplitMode.COL)
assert dm.num_row() == 5
assert dm.num_col() == 5 * xgb.collective.get_world_size()
tm.run_with_rabit(world_size=3, test_fn=verify_list)
def test_tuple(self):
def verify_tuple():
data = (
(1, 2, 3, 4, 5),
(6, 7, 8, 9, 10),
(11, 12, 13, 14, 15),
(16, 17, 18, 19, 20),
(21, 22, 23, 24, 25)
)
dm = xgb.DMatrix(data, data_split_mode=DataSplitMode.COL)
assert dm.num_row() == 5
assert dm.num_col() == 5 * xgb.collective.get_world_size()
tm.run_with_rabit(world_size=3, test_fn=verify_tuple)

View File

@ -1,4 +1,5 @@
import os
import sys
import unittest
import numpy as np
@ -6,6 +7,7 @@ import pytest
import xgboost as xgb
from xgboost import testing as tm
from xgboost.core import DataSplitMode
try:
import pandas as pd
@ -97,3 +99,17 @@ class TestArrowTable:
y_np_low = dtrain.get_float_info("label_lower_bound")
np.testing.assert_equal(y_np_up, y_upper_bound.to_pandas().values)
np.testing.assert_equal(y_np_low, y_lower_bound.to_pandas().values)
class TestArrowTableColumnSplit:
def test_arrow_table(self):
def verify_arrow_table():
df = pd.DataFrame(
[[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"]
)
table = pa.Table.from_pandas(df)
dm = xgb.DMatrix(table, data_split_mode=DataSplitMode.COL)
assert dm.num_row() == 2
assert dm.num_col() == 4 * xgb.collective.get_world_size()
tm.run_with_rabit(world_size=3, test_fn=verify_arrow_table)