Add data split mode to DMatrix MetaInfo (#8568)

This commit is contained in:
Rong Ou
2022-12-25 04:37:37 -08:00
committed by GitHub
parent 77b069c25d
commit 3ceeb8c61c
20 changed files with 113 additions and 103 deletions

View File

@@ -10,6 +10,7 @@ import sys
import warnings
from abc import ABC, abstractmethod
from collections.abc import Mapping
from enum import IntEnum, unique
from functools import wraps
from inspect import Parameter, signature
from typing import (
@@ -608,6 +609,13 @@ def require_keyword_args(
_deprecate_positional_args = require_keyword_args(False)
@unique
class DataSplitMode(IntEnum):
    """Supported data split mode for DMatrix.

    Members are plain integers (``IntEnum``), so a mode can be converted with
    ``int(mode)`` when it has to be handed over as a number.  ``@unique``
    guarantees that no two members share the same value.
    """

    # Default mode used by the DMatrix constructors.
    # NOTE(review): presumably "data is partitioned by rows" -- confirm.
    ROW = 0
    # NOTE(review): presumably "data is partitioned by columns" -- confirm.
    COL = 1
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
"""Data Matrix used in XGBoost.
@@ -635,6 +643,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
label_upper_bound: Optional[ArrayLike] = None,
feature_weights: Optional[ArrayLike] = None,
enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> None:
"""Parameters
----------
@@ -728,6 +737,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
feature_names=feature_names,
feature_types=feature_types,
enable_categorical=enable_categorical,
data_split_mode=data_split_mode,
)
assert handle is not None
self.handle = handle
@@ -1332,6 +1342,7 @@ class QuantileDMatrix(DMatrix):
label_upper_bound: Optional[ArrayLike] = None,
feature_weights: Optional[ArrayLike] = None,
enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> None:
self.max_bin: int = max_bin if max_bin is not None else 256
self.missing = missing if missing is not None else np.nan

View File

@@ -23,6 +23,7 @@ from .compat import DataFrame, lazy_isinstance
from .core import (
_LIB,
DataIter,
DataSplitMode,
DMatrix,
_check_call,
_cuda_array_interface,
@@ -865,13 +866,17 @@ def _from_uri(
missing: Optional[FloatCompatible],
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
_warn_unused_missing(data, missing)
handle = ctypes.c_void_p()
data = os.fspath(os.path.expanduser(data))
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(data),
ctypes.c_int(1),
ctypes.byref(handle)))
args = {
"uri": str(data),
"data_split_mode": int(data_split_mode),
}
config = bytes(json.dumps(args), "utf-8")
_check_call(_LIB.XGDMatrixCreateFromURI(config, ctypes.byref(handle)))
return handle, feature_names, feature_types
@@ -938,6 +943,7 @@ def dispatch_data_backend(
feature_names: Optional[FeatureNames],
feature_types: Optional[FeatureTypes],
enable_categorical: bool = False,
data_split_mode: DataSplitMode = DataSplitMode.ROW,
) -> DispatchedDataBackendReturnType:
'''Dispatch data for DMatrix.'''
if not _is_cudf_ser(data) and not _is_pandas_series(data):
@@ -953,7 +959,7 @@ def dispatch_data_backend(
if _is_numpy_array(data):
return _from_numpy_array(data, missing, threads, feature_names, feature_types)
if _is_uri(data):
return _from_uri(data, missing, feature_names, feature_types)
return _from_uri(data, missing, feature_names, feature_types, data_split_mode)
if _is_list(data):
return _from_list(data, missing, threads, feature_names, feature_types)
if _is_tuple(data):