Support column-wise data split with in-memory inputs (#9628)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
parent
4d1607eefd
commit
da6803b75b
@ -144,9 +144,7 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
|
||||
* See :doc:`/tutorials/input_format` for more info.
|
||||
* \endverbatim
|
||||
* - silent (optional): Whether to print message during loading. Default to true.
|
||||
* - data_split_mode (optional): Whether to split by row or column. In distributed mode, the
|
||||
* file is split accordingly; otherwise this is only an indicator on how the file was split
|
||||
* beforehand. Default to row.
|
||||
* - data_split_mode (optional): Whether the file was split by row or column beforehand for distributed computing. Default to row.
|
||||
* \param out a loaded data matrix
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
@ -174,6 +172,7 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t *indptr, const unsigned *indic
|
||||
* \param config JSON encoded configuration. Required values are:
|
||||
* - missing: Which value to represent missing value.
|
||||
* - nthread (optional): Number of threads used for initializing DMatrix.
|
||||
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
|
||||
* \param out created dmatrix
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
@ -186,6 +185,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
|
||||
* \param config JSON encoded configuration. Required values are:
|
||||
* - missing: Which value to represent missing value.
|
||||
* - nthread (optional): Number of threads used for initializing DMatrix.
|
||||
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
|
||||
* \param out created dmatrix
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
@ -200,6 +200,7 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data, char const *config, DMatr
|
||||
* \param config JSON encoded configuration. Supported values are:
|
||||
* - missing: Which value to represent missing value.
|
||||
* - nthread (optional): Number of threads used for initializing DMatrix.
|
||||
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
|
||||
* \param out created dmatrix
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
@ -266,6 +267,7 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data,
|
||||
* \param config JSON encoded configuration. Required values are:
|
||||
* - missing: Which value to represent missing value.
|
||||
* - nthread (optional): Number of threads used for initializing DMatrix.
|
||||
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
|
||||
* \param out created dmatrix
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
@ -278,6 +280,7 @@ XGB_DLL int XGDMatrixCreateFromCudaColumnar(char const *data, char const *config
|
||||
* \param config JSON encoded configuration. Required values are:
|
||||
* - missing: Which value to represent missing value.
|
||||
* - nthread (optional): Number of threads used for initializing DMatrix.
|
||||
* - data_split_mode (optional): Whether the data was split by row or column beforehand. Default to row.
|
||||
* \param out created dmatrix
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
@ -790,6 +793,16 @@ XGB_DLL int XGDMatrixNumCol(DMatrixHandle handle, bst_ulong *out);
|
||||
*/
|
||||
XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle handle, bst_ulong *out);
|
||||
|
||||
/*!
|
||||
* \brief Get the data split mode from DMatrix.
|
||||
*
|
||||
* \param handle the handle to the DMatrix
|
||||
* \param out The output of the data split mode
|
||||
*
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out);
|
||||
|
||||
/**
|
||||
* \brief Get the predictors from DMatrix as CSR matrix for testing. If this is a
|
||||
* quantized DMatrix, quantized values are returned instead.
|
||||
|
||||
@ -303,14 +303,14 @@ def _check_distributed_params(kwargs: Dict[str, Any]) -> None:
|
||||
|
||||
|
||||
def _validate_feature_info(
|
||||
feature_info: Sequence[str], n_features: int, name: str
|
||||
feature_info: Sequence[str], n_features: int, is_column_split: bool, name: str
|
||||
) -> List[str]:
|
||||
if isinstance(feature_info, str) or not isinstance(feature_info, Sequence):
|
||||
raise TypeError(
|
||||
f"Expecting a sequence of strings for {name}, got: {type(feature_info)}"
|
||||
)
|
||||
feature_info = list(feature_info)
|
||||
if len(feature_info) != n_features and n_features != 0:
|
||||
if len(feature_info) != n_features and n_features != 0 and not is_column_split:
|
||||
msg = (
|
||||
f"{name} must have the same length as the number of data columns, ",
|
||||
f"expected {n_features}, got {len(feature_info)}",
|
||||
@ -1231,6 +1231,16 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
_check_call(_LIB.XGDMatrixNumNonMissing(self.handle, ctypes.byref(ret)))
|
||||
return ret.value
|
||||
|
||||
def data_split_mode(self) -> DataSplitMode:
|
||||
"""Get the data split mode of the DMatrix.
|
||||
|
||||
.. versionadded:: 2.1.0
|
||||
|
||||
"""
|
||||
ret = c_bst_ulong()
|
||||
_check_call(_LIB.XGDMatrixDataSplitMode(self.handle, ctypes.byref(ret)))
|
||||
return DataSplitMode(ret.value)
|
||||
|
||||
def slice(
|
||||
self, rindex: Union[List[int], np.ndarray], allow_groups: bool = False
|
||||
) -> "DMatrix":
|
||||
@ -1298,7 +1308,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
|
||||
# validate feature name
|
||||
feature_names = _validate_feature_info(
|
||||
feature_names, self.num_col(), "feature names"
|
||||
feature_names,
|
||||
self.num_col(),
|
||||
self.data_split_mode() == DataSplitMode.COL,
|
||||
"feature names",
|
||||
)
|
||||
if len(feature_names) != len(set(feature_names)):
|
||||
values, counts = np.unique(
|
||||
@ -1371,7 +1384,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
return
|
||||
|
||||
feature_types = _validate_feature_info(
|
||||
feature_types, self.num_col(), "feature types"
|
||||
feature_types,
|
||||
self.num_col(),
|
||||
self.data_split_mode() == DataSplitMode.COL,
|
||||
"feature types",
|
||||
)
|
||||
|
||||
feature_types_bytes = [bytes(f, encoding="utf-8") for f in feature_types]
|
||||
|
||||
@ -107,6 +107,7 @@ def _from_scipy_csr(
|
||||
nthread: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
"""Initialize data from a CSR matrix."""
|
||||
|
||||
@ -118,7 +119,11 @@ def _from_scipy_csr(
|
||||
_array_interface(data.indices),
|
||||
_array_interface(data.data),
|
||||
c_bst_ulong(data.shape[1]),
|
||||
make_jcargs(missing=float(missing), nthread=int(nthread)),
|
||||
make_jcargs(
|
||||
missing=float(missing),
|
||||
nthread=int(nthread),
|
||||
data_split_mode=int(data_split_mode),
|
||||
),
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
@ -139,6 +144,7 @@ def _from_scipy_csc(
|
||||
nthread: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
"""Initialize data from a CSC matrix."""
|
||||
handle = ctypes.c_void_p()
|
||||
@ -149,7 +155,11 @@ def _from_scipy_csc(
|
||||
_array_interface(data.indices),
|
||||
_array_interface(data.data),
|
||||
c_bst_ulong(data.shape[0]),
|
||||
make_jcargs(missing=float(missing), nthread=int(nthread)),
|
||||
make_jcargs(
|
||||
missing=float(missing),
|
||||
nthread=int(nthread),
|
||||
data_split_mode=int(data_split_mode),
|
||||
),
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
)
|
||||
@ -518,11 +528,14 @@ def _from_pandas_df(
|
||||
nthread: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
data, feature_names, feature_types = _transform_pandas_df(
|
||||
data, enable_categorical, feature_names, feature_types
|
||||
)
|
||||
return _from_numpy_array(data, missing, nthread, feature_names, feature_types)
|
||||
return _from_numpy_array(
|
||||
data, missing, nthread, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
|
||||
|
||||
def _is_pandas_series(data: DataType) -> bool:
|
||||
@ -970,10 +983,13 @@ def _from_list(
|
||||
n_threads: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
array = np.array(data)
|
||||
_check_data_shape(data)
|
||||
return _from_numpy_array(array, missing, n_threads, feature_names, feature_types)
|
||||
return _from_numpy_array(
|
||||
array, missing, n_threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
|
||||
|
||||
def _is_tuple(data: DataType) -> bool:
|
||||
@ -986,8 +1002,11 @@ def _from_tuple(
|
||||
n_threads: int,
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
return _from_list(data, missing, n_threads, feature_names, feature_types)
|
||||
return _from_list(
|
||||
data, missing, n_threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
|
||||
|
||||
def _is_iter(data: DataType) -> bool:
|
||||
@ -1029,12 +1048,21 @@ def dispatch_data_backend(
|
||||
if not _is_cudf_ser(data) and not _is_pandas_series(data):
|
||||
_check_data_shape(data)
|
||||
if _is_scipy_csr(data):
|
||||
return _from_scipy_csr(data, missing, threads, feature_names, feature_types)
|
||||
return _from_scipy_csr(
|
||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
if _is_scipy_csc(data):
|
||||
return _from_scipy_csc(data, missing, threads, feature_names, feature_types)
|
||||
return _from_scipy_csc(
|
||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
if _is_scipy_coo(data):
|
||||
return _from_scipy_csr(
|
||||
data.tocsr(), missing, threads, feature_names, feature_types
|
||||
data.tocsr(),
|
||||
missing,
|
||||
threads,
|
||||
feature_names,
|
||||
feature_types,
|
||||
data_split_mode,
|
||||
)
|
||||
if _is_np_array_like(data):
|
||||
return _from_numpy_array(
|
||||
@ -1043,9 +1071,13 @@ def dispatch_data_backend(
|
||||
if _is_uri(data):
|
||||
return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
|
||||
if _is_list(data):
|
||||
return _from_list(data, missing, threads, feature_names, feature_types)
|
||||
return _from_list(
|
||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
if _is_tuple(data):
|
||||
return _from_tuple(data, missing, threads, feature_names, feature_types)
|
||||
return _from_tuple(
|
||||
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||
)
|
||||
if _is_arrow(data):
|
||||
data = _arrow_transform(data)
|
||||
if _is_pandas_series(data):
|
||||
@ -1054,7 +1086,13 @@ def dispatch_data_backend(
|
||||
data = pd.DataFrame(data)
|
||||
if _is_pandas_df(data):
|
||||
return _from_pandas_df(
|
||||
data, enable_categorical, missing, threads, feature_names, feature_types
|
||||
data,
|
||||
enable_categorical,
|
||||
missing,
|
||||
threads,
|
||||
feature_names,
|
||||
feature_types,
|
||||
data_split_mode,
|
||||
)
|
||||
if _is_cudf_df(data) or _is_cudf_ser(data):
|
||||
return _from_cudf_df(
|
||||
|
||||
@ -10,6 +10,7 @@ import os
|
||||
import platform
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
@ -34,6 +35,7 @@ import pytest
|
||||
from scipy import sparse
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import RabitTracker
|
||||
from xgboost.core import ArrayLike
|
||||
from xgboost.sklearn import SklObjective
|
||||
from xgboost.testing.data import (
|
||||
@ -938,3 +940,22 @@ def load_agaricus(path: str) -> Tuple[xgb.DMatrix, xgb.DMatrix]:
|
||||
|
||||
def project_root(path: str) -> str:
|
||||
return normpath(os.path.join(demo_dir(path), os.path.pardir))
|
||||
|
||||
|
||||
def run_with_rabit(world_size: int, test_fn: Callable) -> None:
|
||||
tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size)
|
||||
tracker.start(world_size)
|
||||
|
||||
def run_worker(rabit_env: Dict[str, Union[str, int]]) -> None:
|
||||
with xgb.collective.CommunicatorContext(**rabit_env):
|
||||
test_fn()
|
||||
|
||||
workers = []
|
||||
for _ in range(world_size):
|
||||
worker = threading.Thread(target=run_worker, args=(tracker.worker_envs(),))
|
||||
workers.append(worker)
|
||||
worker.start()
|
||||
for worker in workers:
|
||||
worker.join()
|
||||
|
||||
tracker.join()
|
||||
|
||||
@ -445,8 +445,11 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
|
||||
auto data_split_mode =
|
||||
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
*out = new std::shared_ptr<DMatrix>(
|
||||
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@ -481,8 +484,11 @@ XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||
auto data_split_mode =
|
||||
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
*out = new std::shared_ptr<DMatrix>(
|
||||
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
|
||||
|
||||
API_END();
|
||||
}
|
||||
@ -722,6 +728,15 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulon
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = CastDMatrixHandle(handle);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = static_cast<xgboost::bst_ulong>(p_m->Info().data_split_mode);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
|
||||
xgboost::bst_ulong *out_indptr, unsigned *out_indices,
|
||||
float *out_data) {
|
||||
|
||||
@ -61,6 +61,7 @@ class RabitCommunicator : public Communicator {
|
||||
auto const total_size = per_rank * GetWorldSize();
|
||||
auto const index = per_rank * GetRank();
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(index, per_rank, input);
|
||||
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
|
||||
return result;
|
||||
}
|
||||
@ -71,7 +72,8 @@ class RabitCommunicator : public Communicator {
|
||||
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
|
||||
auto const begin_index =
|
||||
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
|
||||
auto const size_prev_slice = GetRank() == 0 ? 0 : all_sizes[GetRank() - 1];
|
||||
auto const size_prev_slice =
|
||||
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
|
||||
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(begin_index, size_node_slice, input);
|
||||
|
||||
@ -635,22 +635,39 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
|
||||
}
|
||||
|
||||
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
|
||||
if (size != 0 && this->num_col_ != 0) {
|
||||
if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) {
|
||||
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
|
||||
CHECK(info);
|
||||
}
|
||||
if (!std::strcmp(key, "feature_type")) {
|
||||
feature_type_names.clear();
|
||||
auto& h_feature_types = feature_types.HostVector();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto elem = info[i];
|
||||
feature_type_names.emplace_back(elem);
|
||||
}
|
||||
if (IsColumnSplit()) {
|
||||
feature_type_names = collective::AllgatherStrings(feature_type_names);
|
||||
CHECK_EQ(feature_type_names.size(), num_col_)
|
||||
<< "Length of " << key << " must be equal to number of columns.";
|
||||
}
|
||||
auto& h_feature_types = feature_types.HostVector();
|
||||
LoadFeatureType(feature_type_names, &h_feature_types);
|
||||
} else if (!std::strcmp(key, "feature_name")) {
|
||||
feature_names.clear();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
feature_names.emplace_back(info[i]);
|
||||
if (IsColumnSplit()) {
|
||||
std::vector<std::string> local_feature_names{};
|
||||
auto const rank = collective::GetRank();
|
||||
for (std::size_t i = 0; i < size; ++i) {
|
||||
auto elem = std::to_string(rank) + "." + info[i];
|
||||
local_feature_names.emplace_back(elem);
|
||||
}
|
||||
feature_names = collective::AllgatherStrings(local_feature_names);
|
||||
CHECK_EQ(feature_names.size(), num_col_)
|
||||
<< "Length of " << key << " must be equal to number of columns.";
|
||||
} else {
|
||||
feature_names.clear();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
feature_names.emplace_back(info[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown feature info name: " << key;
|
||||
|
||||
@ -75,7 +75,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
}
|
||||
|
||||
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
|
||||
if (info_.IsColumnSplit()) {
|
||||
if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) {
|
||||
auto const cols = collective::Allgather(info_.num_col_);
|
||||
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
|
||||
if (offset == 0) {
|
||||
|
||||
@ -108,6 +108,7 @@ TEST(CAPI, XGDMatrixCreateFromCSR) {
|
||||
Json::Dump(data_arr, &sdata);
|
||||
Json config{Object{}};
|
||||
config["missing"] = Number{std::numeric_limits<float>::quiet_NaN()};
|
||||
config["data_split_mode"] = Integer{static_cast<int64_t>(DataSplitMode::kCol)};
|
||||
Json::Dump(config, &sconfig);
|
||||
|
||||
DMatrixHandle handle;
|
||||
@ -120,6 +121,8 @@ TEST(CAPI, XGDMatrixCreateFromCSR) {
|
||||
ASSERT_EQ(n, 3);
|
||||
ASSERT_EQ(XGDMatrixNumNonMissing(handle, &n), 0);
|
||||
ASSERT_EQ(n, 3);
|
||||
ASSERT_EQ(XGDMatrixDataSplitMode(handle, &n), 0);
|
||||
ASSERT_EQ(n, static_cast<int64_t>(DataSplitMode::kCol));
|
||||
|
||||
std::shared_ptr<xgboost::DMatrix> *pp_fmat =
|
||||
static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
|
||||
@ -74,6 +74,49 @@ TEST(MetaInfo, GetSetFeature) {
|
||||
// Other conditions are tested in `SaveLoadBinary`.
|
||||
}
|
||||
|
||||
namespace {
|
||||
void VerifyGetSetFeatureColumnSplit() {
|
||||
xgboost::MetaInfo info;
|
||||
info.data_split_mode = DataSplitMode::kCol;
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
|
||||
auto constexpr kCols{2};
|
||||
std::vector<std::string> types{u8"float", u8"c"};
|
||||
std::vector<char const *> c_types(kCols);
|
||||
std::transform(types.cbegin(), types.cend(), c_types.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
info.num_col_ = kCols;
|
||||
EXPECT_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()), dmlc::Error);
|
||||
info.num_col_ = kCols * world_size;
|
||||
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_type", c_types.data(), c_types.size()));
|
||||
std::vector<std::string> expected_type_names{u8"float", u8"c", u8"float",
|
||||
u8"c", u8"float", u8"c"};
|
||||
EXPECT_EQ(info.feature_type_names, expected_type_names);
|
||||
std::vector<xgboost::FeatureType> expected_types{
|
||||
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical,
|
||||
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical,
|
||||
xgboost::FeatureType::kNumerical, xgboost::FeatureType::kCategorical};
|
||||
EXPECT_EQ(info.feature_types.HostVector(), expected_types);
|
||||
|
||||
std::vector<std::string> names{u8"feature0", u8"feature1"};
|
||||
std::vector<char const *> c_names(kCols);
|
||||
std::transform(names.cbegin(), names.cend(), c_names.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
info.num_col_ = kCols;
|
||||
EXPECT_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()), dmlc::Error);
|
||||
info.num_col_ = kCols * world_size;
|
||||
EXPECT_NO_THROW(info.SetFeatureInfo(u8"feature_name", c_names.data(), c_names.size()));
|
||||
std::vector<std::string> expected_names{u8"0.feature0", u8"0.feature1", u8"1.feature0",
|
||||
u8"1.feature1", u8"2.feature0", u8"2.feature1"};
|
||||
EXPECT_EQ(info.feature_names, expected_names);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(MetaInfo, GetSetFeatureColumnSplit) {
|
||||
auto constexpr kWorldSize{3};
|
||||
RunWithInMemoryCommunicator(kWorldSize, VerifyGetSetFeatureColumnSplit);
|
||||
}
|
||||
|
||||
TEST(MetaInfo, SaveLoadBinary) {
|
||||
xgboost::MetaInfo info;
|
||||
xgboost::Context ctx;
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
@ -9,6 +10,7 @@ from scipy.sparse import csr_matrix, rand
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.core import DataSplitMode
|
||||
from xgboost.testing.data import np_dtypes
|
||||
|
||||
rng = np.random.RandomState(1)
|
||||
@ -467,3 +469,97 @@ class TestDMatrix:
|
||||
m0 = xgb.DMatrix(orig)
|
||||
m1 = xgb.DMatrix(x)
|
||||
assert tm.predictor_equal(m0, m1)
|
||||
|
||||
|
||||
class TestDMatrixColumnSplit:
|
||||
def test_numpy(self):
|
||||
def verify_numpy():
|
||||
data = np.random.randn(5, 5)
|
||||
dm = xgb.DMatrix(data, data_split_mode=DataSplitMode.COL)
|
||||
assert dm.num_row() == 5
|
||||
assert dm.num_col() == 5 * xgb.collective.get_world_size()
|
||||
assert dm.feature_names is None
|
||||
assert dm.feature_types is None
|
||||
|
||||
tm.run_with_rabit(world_size=3, test_fn=verify_numpy)
|
||||
|
||||
def test_numpy_feature_names(self):
|
||||
def verify_numpy_feature_names():
|
||||
world_size = xgb.collective.get_world_size()
|
||||
data = np.random.randn(5, 5)
|
||||
feature_names = [f'feature{x}' for x in range(5)]
|
||||
feature_types = ['float'] * 5
|
||||
dm = xgb.DMatrix(data, feature_names=feature_names, feature_types=feature_types,
|
||||
data_split_mode=DataSplitMode.COL)
|
||||
assert dm.num_row() == 5
|
||||
assert dm.num_col() == 5 * world_size
|
||||
assert len(dm.feature_names) == 5 * world_size
|
||||
assert len(dm.feature_types) == 5 * world_size
|
||||
|
||||
tm.run_with_rabit(world_size=3, test_fn=verify_numpy_feature_names)
|
||||
|
||||
def test_csr(self):
|
||||
def verify_csr():
|
||||
indptr = np.array([0, 2, 3, 6])
|
||||
indices = np.array([0, 2, 2, 0, 1, 2])
|
||||
data = np.array([1, 2, 3, 4, 5, 6])
|
||||
X = scipy.sparse.csr_matrix((data, indices, indptr), shape=(3, 3))
|
||||
dtrain = xgb.DMatrix(X, data_split_mode=DataSplitMode.COL)
|
||||
assert dtrain.num_row() == 3
|
||||
assert dtrain.num_col() == 3 * xgb.collective.get_world_size()
|
||||
|
||||
tm.run_with_rabit(world_size=3, test_fn=verify_csr)
|
||||
|
||||
def test_csc(self):
|
||||
def verify_csc():
|
||||
row = np.array([0, 2, 2, 0, 1, 2])
|
||||
col = np.array([0, 0, 1, 2, 2, 2])
|
||||
data = np.array([1, 2, 3, 4, 5, 6])
|
||||
X = scipy.sparse.csc_matrix((data, (row, col)), shape=(3, 3))
|
||||
dtrain = xgb.DMatrix(X, data_split_mode=DataSplitMode.COL)
|
||||
assert dtrain.num_row() == 3
|
||||
assert dtrain.num_col() == 3 * xgb.collective.get_world_size()
|
||||
|
||||
tm.run_with_rabit(world_size=3, test_fn=verify_csc)
|
||||
|
||||
def test_coo(self):
|
||||
def verify_coo():
|
||||
row = np.array([0, 2, 2, 0, 1, 2])
|
||||
col = np.array([0, 0, 1, 2, 2, 2])
|
||||
data = np.array([1, 2, 3, 4, 5, 6])
|
||||
X = scipy.sparse.coo_matrix((data, (row, col)), shape=(3, 3))
|
||||
dtrain = xgb.DMatrix(X, data_split_mode=DataSplitMode.COL)
|
||||
assert dtrain.num_row() == 3
|
||||
assert dtrain.num_col() == 3 * xgb.collective.get_world_size()
|
||||
|
||||
tm.run_with_rabit(world_size=3, test_fn=verify_coo)
|
||||
|
||||
def test_list(self):
|
||||
def verify_list():
|
||||
data = [
|
||||
[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10],
|
||||
[11, 12, 13, 14, 15],
|
||||
[16, 17, 18, 19, 20],
|
||||
[21, 22, 23, 24, 25]
|
||||
]
|
||||
dm = xgb.DMatrix(data, data_split_mode=DataSplitMode.COL)
|
||||
assert dm.num_row() == 5
|
||||
assert dm.num_col() == 5 * xgb.collective.get_world_size()
|
||||
|
||||
tm.run_with_rabit(world_size=3, test_fn=verify_list)
|
||||
|
||||
def test_tuple(self):
|
||||
def verify_tuple():
|
||||
data = (
|
||||
(1, 2, 3, 4, 5),
|
||||
(6, 7, 8, 9, 10),
|
||||
(11, 12, 13, 14, 15),
|
||||
(16, 17, 18, 19, 20),
|
||||
(21, 22, 23, 24, 25)
|
||||
)
|
||||
dm = xgb.DMatrix(data, data_split_mode=DataSplitMode.COL)
|
||||
assert dm.num_row() == 5
|
||||
assert dm.num_col() == 5 * xgb.collective.get_world_size()
|
||||
|
||||
tm.run_with_rabit(world_size=3, test_fn=verify_tuple)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@ -6,6 +7,7 @@ import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.core import DataSplitMode
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
@ -97,3 +99,17 @@ class TestArrowTable:
|
||||
y_np_low = dtrain.get_float_info("label_lower_bound")
|
||||
np.testing.assert_equal(y_np_up, y_upper_bound.to_pandas().values)
|
||||
np.testing.assert_equal(y_np_low, y_lower_bound.to_pandas().values)
|
||||
|
||||
|
||||
class TestArrowTableColumnSplit:
|
||||
def test_arrow_table(self):
|
||||
def verify_arrow_table():
|
||||
df = pd.DataFrame(
|
||||
[[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"]
|
||||
)
|
||||
table = pa.Table.from_pandas(df)
|
||||
dm = xgb.DMatrix(table, data_split_mode=DataSplitMode.COL)
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 4 * xgb.collective.get_world_size()
|
||||
|
||||
tm.run_with_rabit(world_size=3, test_fn=verify_arrow_table)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user