Support numpy vertical split (#9365)

This commit is contained in:
edumugi 2023-07-08 07:18:12 +02:00 committed by GitHub
parent 59787b23af
commit c3124813e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 3 deletions

View File

@ -197,6 +197,7 @@ def _from_numpy_array(
nthread: int,
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
"""Initialize data from a 2-D numpy matrix."""
_check_data_shape(data)
@ -205,7 +206,11 @@ def _from_numpy_array(
_check_call(
_LIB.XGDMatrixCreateFromDense(
_array_interface(data),
make_jcargs(missing=float(missing), nthread=int(nthread)),
make_jcargs(
missing=float(missing),
nthread=int(nthread),
data_split_mode=int(data_split_mode),
),
ctypes.byref(handle),
)
)
@ -1046,7 +1051,9 @@ def dispatch_data_backend(
data.tocsr(), missing, threads, feature_names, feature_types
)
if _is_numpy_array(data):
return _from_numpy_array(data, missing, threads, feature_names, feature_types)
return _from_numpy_array(
data, missing, threads, feature_names, feature_types, data_split_mode
)
if _is_uri(data):
return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
if _is_list(data):

View File

@ -463,8 +463,11 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data,
auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
API_END();
}