Add data split mode to DMatrix MetaInfo (#8568)
This commit is contained in:
@@ -10,6 +10,7 @@ import sys
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from enum import IntEnum, unique
|
||||
from functools import wraps
|
||||
from inspect import Parameter, signature
|
||||
from typing import (
|
||||
@@ -608,6 +609,13 @@ def require_keyword_args(
|
||||
_deprecate_positional_args = require_keyword_args(False)
|
||||
|
||||
|
||||
@unique
class DataSplitMode(IntEnum):
    """Supported data split mode for DMatrix.

    The members are plain integers (``IntEnum``), so a mode can be passed
    through ``int(...)`` when it is serialised, e.g. into a JSON
    configuration for the native library.  ``@unique`` guarantees that no
    two members share the same integer value.
    """

    # NOTE(review): presumably the data is partitioned by rows across
    # workers -- confirm against the native implementation.
    ROW = 0
    # NOTE(review): presumably the data is partitioned by columns
    # (features) across workers -- confirm against the native implementation.
    COL = 1
|
||||
|
||||
|
||||
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
|
||||
"""Data Matrix used in XGBoost.
|
||||
|
||||
@@ -635,6 +643,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
label_upper_bound: Optional[ArrayLike] = None,
|
||||
feature_weights: Optional[ArrayLike] = None,
|
||||
enable_categorical: bool = False,
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> None:
|
||||
"""Parameters
|
||||
----------
|
||||
@@ -728,6 +737,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
enable_categorical=enable_categorical,
|
||||
data_split_mode=data_split_mode,
|
||||
)
|
||||
assert handle is not None
|
||||
self.handle = handle
|
||||
@@ -1332,6 +1342,7 @@ class QuantileDMatrix(DMatrix):
|
||||
label_upper_bound: Optional[ArrayLike] = None,
|
||||
feature_weights: Optional[ArrayLike] = None,
|
||||
enable_categorical: bool = False,
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> None:
|
||||
self.max_bin: int = max_bin if max_bin is not None else 256
|
||||
self.missing = missing if missing is not None else np.nan
|
||||
|
||||
@@ -23,6 +23,7 @@ from .compat import DataFrame, lazy_isinstance
|
||||
from .core import (
|
||||
_LIB,
|
||||
DataIter,
|
||||
DataSplitMode,
|
||||
DMatrix,
|
||||
_check_call,
|
||||
_cuda_array_interface,
|
||||
@@ -865,13 +866,17 @@ def _from_uri(
|
||||
missing: Optional[FloatCompatible],
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
_warn_unused_missing(data, missing)
|
||||
handle = ctypes.c_void_p()
|
||||
data = os.fspath(os.path.expanduser(data))
|
||||
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
|
||||
ctypes.c_int(1),
|
||||
ctypes.byref(handle)))
|
||||
args = {
|
||||
"uri": str(data),
|
||||
"data_split_mode": int(data_split_mode),
|
||||
}
|
||||
config = bytes(json.dumps(args), "utf-8")
|
||||
_check_call(_LIB.XGDMatrixCreateFromURI(config, ctypes.byref(handle)))
|
||||
return handle, feature_names, feature_types
|
||||
|
||||
|
||||
@@ -938,6 +943,7 @@ def dispatch_data_backend(
|
||||
feature_names: Optional[FeatureNames],
|
||||
feature_types: Optional[FeatureTypes],
|
||||
enable_categorical: bool = False,
|
||||
data_split_mode: DataSplitMode = DataSplitMode.ROW,
|
||||
) -> DispatchedDataBackendReturnType:
|
||||
'''Dispatch data for DMatrix.'''
|
||||
if not _is_cudf_ser(data) and not _is_pandas_series(data):
|
||||
@@ -953,7 +959,7 @@ def dispatch_data_backend(
|
||||
if _is_numpy_array(data):
|
||||
return _from_numpy_array(data, missing, threads, feature_names, feature_types)
|
||||
if _is_uri(data):
|
||||
return _from_uri(data, missing, feature_names, feature_types)
|
||||
return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
|
||||
if _is_list(data):
|
||||
return _from_list(data, missing, threads, feature_names, feature_types)
|
||||
if _is_tuple(data):
|
||||
|
||||
Reference in New Issue
Block a user