Support numpy vertical split (#9365)
This commit is contained in:
parent
59787b23af
commit
c3124813e8
@ -197,6 +197,7 @@ def _from_numpy_array(
|
|||||||
nthread: int,
|
nthread: int,
|
||||||
feature_names: Optional[FeatureNames],
|
feature_names: Optional[FeatureNames],
|
||||||
feature_types: Optional[FeatureTypes],
|
feature_types: Optional[FeatureTypes],
|
||||||
|
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||||
) -> DispatchedDataBackendReturnType:
|
) -> DispatchedDataBackendReturnType:
|
||||||
"""Initialize data from a 2-D numpy matrix."""
|
"""Initialize data from a 2-D numpy matrix."""
|
||||||
_check_data_shape(data)
|
_check_data_shape(data)
|
||||||
@ -205,7 +206,11 @@ def _from_numpy_array(
|
|||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGDMatrixCreateFromDense(
|
_LIB.XGDMatrixCreateFromDense(
|
||||||
_array_interface(data),
|
_array_interface(data),
|
||||||
make_jcargs(missing=float(missing), nthread=int(nthread)),
|
make_jcargs(
|
||||||
|
missing=float(missing),
|
||||||
|
nthread=int(nthread),
|
||||||
|
data_split_mode=int(data_split_mode),
|
||||||
|
),
|
||||||
ctypes.byref(handle),
|
ctypes.byref(handle),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -1046,7 +1051,9 @@ def dispatch_data_backend(
|
|||||||
data.tocsr(), missing, threads, feature_names, feature_types
|
data.tocsr(), missing, threads, feature_names, feature_types
|
||||||
)
|
)
|
||||||
if _is_numpy_array(data):
|
if _is_numpy_array(data):
|
||||||
return _from_numpy_array(data, missing, threads, feature_names, feature_types)
|
return _from_numpy_array(
|
||||||
|
data, missing, threads, feature_names, feature_types, data_split_mode
|
||||||
|
)
|
||||||
if _is_uri(data):
|
if _is_uri(data):
|
||||||
return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
|
return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
|
||||||
if _is_list(data):
|
if _is_list(data):
|
||||||
|
|||||||
@ -463,8 +463,11 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data,
|
|||||||
auto config = Json::Load(StringView{c_json_config});
|
auto config = Json::Load(StringView{c_json_config});
|
||||||
float missing = GetMissing(config);
|
float missing = GetMissing(config);
|
||||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
|
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
|
||||||
|
auto data_split_mode =
|
||||||
|
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
|
||||||
xgboost_CHECK_C_ARG_PTR(out);
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
*out = new std::shared_ptr<DMatrix>(
|
||||||
|
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user