Support more input types for categorical data. (#7220)
* Support more input types for categorical data.
* Shorten the type name from "categorical" to "c".
* Tests for np/cp array and scipy csr/csc/coo.
* Specify the type for feature info.
parent 2942dc68e4
commit 0ed979b096
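In short: feature types can now be supplied for NumPy arrays, CuPy arrays, and SciPy CSR/CSC/COO matrices, not only for pandas/cudf DataFrames, and the categorical type name is shortened to the single character "c". A minimal sketch of the new usage, modelled on the tests added in this commit (the random data here is illustrative and not part of the change):

    import numpy as np
    import xgboost as xgb

    # Small integer codes standing in for categorical values.
    rng = np.random.default_rng(0)
    X = rng.integers(0, 4, size=(100, 10)).astype(np.float32)
    y = rng.random(100)

    # "c" marks a column as categorical; "q" marks it as quantitative.
    feature_types = ["c"] * X.shape[1]

    # A plain numpy array is accepted directly; no DataFrame is needed.
    Xy = xgb.DMatrix(X, y, feature_types=feature_types)
    assert list(Xy.feature_types) == feature_types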
@@ -44,7 +44,8 @@ def make_categorical(

 def main() -> None:
     # Use builtin categorical data support
-    # Must be pandas DataFrame or cudf DataFrame with categorical data
+    # For scikit-learn interface, the input data must be pandas DataFrame or cudf
+    # DataFrame with categorical features
     X, y = make_categorical(100, 10, 4, False)
     # Specify `enable_categorical` to True.
     reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True)
@@ -83,7 +83,7 @@ class FeatureMap {
    if (!strcmp("q", tname)) return kQuantitive;
    if (!strcmp("int", tname)) return kInteger;
    if (!strcmp("float", tname)) return kFloat;
-    if (!strcmp("categorical", tname)) return kCategorical;
+    if (!strcmp("c", tname)) return kCategorical;
    LOG(FATAL) << "unknown feature type, use i for indicator and q for quantity";
    return kIndicator;
  }
@@ -518,8 +518,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
         base_margin=None,
         missing: Optional[float] = None,
         silent=False,
-        feature_names=None,
-        feature_types=None,
+        feature_names: Optional[List[str]] = None,
+        feature_types: Optional[List[str]] = None,
         nthread: Optional[int] = None,
         group=None,
         qid=None,
@@ -558,8 +558,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
             Whether print messages during construction
         feature_names : list, optional
             Set names for features.
-        feature_types : list, optional
-            Set types for features.
+        feature_types :
+
+            Set types for features. When `enable_categorical` is set to `True`, string
+            "c" represents categorical data type.
+
         nthread : integer, optional
             Number of threads to use for loading data when parallelization is
             applicable. If -1, uses maximum threads available on the system.
@@ -577,11 +580,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes

             .. versionadded:: 1.3.0

-            Experimental support of specializing for categorical features. Do
-            not set to True unless you are interested in development.
-            Currently it's only available for `gpu_hist` tree method with 1 vs
-            rest (one hot) categorical split. Also, JSON serialization format,
-            `gpu_predictor` and pandas input are required.
+            Experimental support of specializing for categorical features. Do not set to
+            True unless you are interested in development. Currently it's only available
+            for `gpu_hist` tree method with 1 vs rest (one hot) categorical split. Also,
+            JSON serialization format is required.

         """
         if group is not None and qid is not None:
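The docstring above pins down the contract: either pass feature_types with "c" entries yourself, or set enable_categorical=True and let categorical dtypes in a DataFrame be mapped to "c" automatically. A hedged sketch of the second path (pandas input; the values are made up for illustration):

    import pandas as pd
    import xgboost as xgb

    # One categorical column and one numeric column.
    X = pd.DataFrame({
        "color": pd.Series(["red", "green", "blue", "green"], dtype="category"),
        "size": [1.0, 2.0, 3.0, 4.0],
    })
    y = [0.0, 1.0, 0.0, 1.0]

    # The category dtype is translated to feature type "c"; float64 becomes "float".
    Xy = xgb.DMatrix(X, y, enable_categorical=True)
    print(Xy.feature_types)  # expected: ['c', 'float']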
@@ -673,8 +675,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
         qid=None,
         label_lower_bound=None,
         label_upper_bound=None,
-        feature_names=None,
-        feature_types=None,
+        feature_names: Optional[List[str]] = None,
+        feature_types: Optional[List[str]] = None,
         feature_weights=None
     ) -> None:
         """Set meta info for DMatrix. See doc string for :py:obj:`xgboost.DMatrix`."""
@@ -945,7 +947,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
         return res

     @property
-    def feature_names(self) -> List[str]:
+    def feature_names(self) -> Optional[List[str]]:
         """Get feature names (column labels).

         Returns
@@ -1033,17 +1035,21 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
         return res

     @feature_types.setter
-    def feature_types(self, feature_types: Optional[Union[List[Any], Any]]) -> None:
+    def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None:
         """Set feature types (column types).

-        This is for displaying the results and unrelated
-        to the learning process.
+        This is for displaying the results and categorical data support. See doc string
+        of :py:obj:`xgboost.DMatrix` for details.

         Parameters
         ----------
         feature_types : list or None
             Labels for features. None will reset existing feature names

         """
+        # For compatibility reason this function wraps single str input into a list. But
+        # we should not promote such usage since other than visualization, the field is
+        # also used for specifying categorical data type.
         if feature_types is not None:
             if not isinstance(feature_types, (list, str)):
                 raise TypeError(
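The comment added in the setter explains why a bare string is still accepted: it is broadcast to every column, although lists (or None) are now the preferred form since the field also drives categorical handling. A small sketch of that behaviour; the toy matrix is illustrative only:

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(5, 3).astype(np.float32)
    Xy = xgb.DMatrix(X, label=np.zeros(5))

    # A single string is wrapped into a per-column list, so every column
    # receives the same type; lists are taken as-is.
    Xy.feature_types = "q"
    print(Xy.feature_types)   # expected: ['q', 'q', 'q']

    # None clears the field again.
    Xy.feature_types = None
    print(Xy.feature_types)   # expected: None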
@@ -2461,8 +2467,13 @@ class Booster(object):

             raise ValueError(msg.format(self.feature_names, data.feature_names))

-    def get_split_value_histogram(self, feature, fmap='', bins=None,
-                                  as_pandas=True):
+    def get_split_value_histogram(
+        self,
+        feature: str,
+        fmap: Union[os.PathLike, str] = '',
+        bins: Optional[int] = None,
+        as_pandas: bool = True
+    ):
         """Get split value histogram of a feature

         Parameters
@@ -2510,7 +2521,7 @@ class Booster(object):
         except (ValueError, AttributeError, TypeError):
             # None.index: attr err, None[0]: type err, fn.index(-1): value err
             feature_t = None
-        if feature_t == "categorical":
+        if feature_t == "c":  # categorical
             raise ValueError(
                 "Split value historgam doesn't support categorical split."
             )
@@ -5,7 +5,7 @@ import ctypes
 import json
 import warnings
 import os
-from typing import Any, Tuple, Callable
+from typing import Any, Tuple, Callable, Optional, List

 import numpy as np

@@ -16,6 +16,8 @@ from .compat import lazy_isinstance

 c_bst_ulong = ctypes.c_uint64  # pylint: disable=invalid-name

+CAT_T = "c"
+

 def _warn_unused_missing(data, missing):
     if (missing is not None) and (not np.isnan(missing)):
@@ -57,7 +59,13 @@ def _array_interface(data: np.ndarray) -> bytes:
     return interface_str


-def _from_scipy_csr(data, missing, nthread, feature_names, feature_types):
+def _from_scipy_csr(
+    data,
+    missing,
+    nthread,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+):
     """Initialize data from a CSR matrix."""
     if len(data.indices) != len(data.data):
         raise ValueError(
@@ -91,7 +99,12 @@ def _is_scipy_csc(data):
     return isinstance(data, scipy.sparse.csc_matrix)


-def _from_scipy_csc(data, missing, feature_names, feature_types):
+def _from_scipy_csc(
+    data,
+    missing,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+):
     if len(data.indices) != len(data.data):
         raise ValueError('length mismatch: {} vs {}'.format(
             len(data.indices), len(data.data)))
@@ -142,7 +155,13 @@ def _maybe_np_slice(data, dtype):
     return data


-def _from_numpy_array(data, missing, nthread, feature_names, feature_types):
+def _from_numpy_array(
+    data,
+    missing,
+    nthread,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+):
     """Initialize data from a 2-D numpy matrix.

     """
@@ -199,9 +218,14 @@ _pandas_dtype_mapper = {
 }


-def _transform_pandas_df(data, enable_categorical,
-                         feature_names=None, feature_types=None,
-                         meta=None, meta_type=None):
+def _transform_pandas_df(
+    data,
+    enable_categorical,
+    feature_names: Optional[List[str]] = None,
+    feature_types: Optional[List[str]] = None,
+    meta=None,
+    meta_type=None,
+):
     from pandas import MultiIndex, Int64Index, RangeIndex
     from pandas.api.types import is_sparse, is_categorical_dtype

@@ -236,7 +260,7 @@ def _transform_pandas_df(data, enable_categorical,
             feature_types.append(_pandas_dtype_mapper[
                 dtype.subtype.name])
         elif is_categorical_dtype(dtype) and enable_categorical:
-            feature_types.append('categorical')
+            feature_types.append(CAT_T)
         else:
             feature_types.append(_pandas_dtype_mapper[dtype.name])

@@ -253,8 +277,14 @@ def _transform_pandas_df(data, enable_categorical,
     return data, feature_names, feature_types


-def _from_pandas_df(data, enable_categorical, missing, nthread,
-                    feature_names, feature_types):
+def _from_pandas_df(
+    data,
+    enable_categorical: bool,
+    missing,
+    nthread,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+):
     data, feature_names, feature_types = _transform_pandas_df(
         data, enable_categorical, feature_names, feature_types)
     return _from_numpy_array(data, missing, nthread, feature_names,
@@ -277,9 +307,16 @@ def _is_modin_series(data):
     return isinstance(data, pd.Series)


-def _from_pandas_series(data, missing, nthread, feature_types, feature_names):
-    return _from_numpy_array(data.values.astype('float'), missing, nthread,
-                             feature_names, feature_types)
+def _from_pandas_series(
+    data,
+    missing,
+    nthread,
+    feature_types: Optional[List[str]],
+    feature_names: Optional[List[str]],
+):
+    return _from_numpy_array(
+        data.values.astype("float"), missing, nthread, feature_names, feature_types
+    )


 def _is_dt_df(data):
@@ -291,8 +328,13 @@ _dt_type_mapper = {'bool': 'bool', 'int': 'int', 'real': 'float'}
 _dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'}


-def _transform_dt_df(data, feature_names, feature_types, meta=None,
-                     meta_type=None):
+def _transform_dt_df(
+    data,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+    meta=None,
+    meta_type=None,
+):
     """Validate feature names and types if data table"""
     if meta and data.shape[1] > 1:
         raise ValueError(
@@ -325,7 +367,16 @@ def _transform_dt_df(data, feature_names, feature_types, meta=None,
     return data, feature_names, feature_types


-def _from_dt_df(data, missing, nthread, feature_names, feature_types):
+def _from_dt_df(
+    data,
+    missing,
+    nthread,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+    enable_categorical: bool,
+) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
+    if enable_categorical:
+        raise ValueError("categorical data in datatable is not supported yet.")
     data, feature_names, feature_types = _transform_dt_df(
         data, feature_names, feature_types, None, None)

@@ -368,7 +419,7 @@ def _is_cudf_df(data):
     return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame)


-def _cudf_array_interfaces(data) -> Tuple[list, list]:
+def _cudf_array_interfaces(data) -> Tuple[list, bytes]:
     """Extract CuDF __cuda_array_interface__. This is special as it returns a new list of
     data and a list of array interfaces. The data is list of categorical codes that
     caller can safely ignore, but have to keep their reference alive until usage of array
@@ -395,7 +446,12 @@ def _cudf_array_interfaces(data) -> Tuple[list, list]:
     return cat_codes, interfaces_str


-def _transform_cudf_df(data, feature_names, feature_types, enable_categorical):
+def _transform_cudf_df(
+    data,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+    enable_categorical: bool,
+):
     from cudf.utils.dtypes import is_categorical_dtype

     if feature_names is None:
@@ -413,14 +469,19 @@ def _transform_cudf_df(data, feature_names, feature_types, enable_categorical):
     dtypes = data.dtypes
     for dtype in dtypes:
         if is_categorical_dtype(dtype) and enable_categorical:
-            feature_types.append("categorical")
+            feature_types.append(CAT_T)
         else:
             feature_types.append(_pandas_dtype_mapper[dtype.name])
     return data, feature_names, feature_types


 def _from_cudf_df(
-    data, missing, nthread, feature_names, feature_types, enable_categorical
+    data,
+    missing,
+    nthread,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+    enable_categorical: bool,
 ) -> Tuple[ctypes.c_void_p, Any, Any]:
     data, feature_names, feature_types = _transform_cudf_df(
         data, feature_names, feature_types, enable_categorical
@@ -464,7 +525,13 @@ def _transform_cupy_array(data):
     return data


-def _from_cupy_array(data, missing, nthread, feature_names, feature_types):
+def _from_cupy_array(
+    data,
+    missing,
+    nthread,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+):
     """Initialize DMatrix from cupy ndarray."""
     data = _transform_cupy_array(data)
     interface_str = _cuda_array_interface(data)
@@ -505,7 +572,13 @@ def _transform_dlpack(data):
     return data


-def _from_dlpack(data, missing, nthread, feature_names, feature_types):
+def _from_dlpack(
+    data,
+    missing,
+    nthread,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+):
     data = _transform_dlpack(data)
     return _from_cupy_array(data, missing, nthread, feature_names,
                             feature_types)
@@ -515,7 +588,12 @@ def _is_uri(data):
     return isinstance(data, (str, os.PathLike))


-def _from_uri(data, missing, feature_names, feature_types):
+def _from_uri(
+    data,
+    missing,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+):
     _warn_unused_missing(data, missing)
     handle = ctypes.c_void_p()
     data = os.fspath(os.path.expanduser(data))
@@ -529,7 +607,13 @@ def _is_list(data):
     return isinstance(data, list)


-def _from_list(data, missing, n_threads, feature_names, feature_types):
+def _from_list(
+    data,
+    missing,
+    n_threads,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+):
     array = np.array(data)
     _check_data_shape(data)
     return _from_numpy_array(array, missing, n_threads, feature_names, feature_types)
@@ -539,7 +623,13 @@ def _is_tuple(data):
     return isinstance(data, tuple)


-def _from_tuple(data, missing, n_threads, feature_names, feature_types):
+def _from_tuple(
+    data,
+    missing,
+    n_threads,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+):
     return _from_list(data, missing, n_threads, feature_names, feature_types)


@@ -569,9 +659,14 @@ def _convert_unknown_data(data):
     return data


-def dispatch_data_backend(data, missing, threads,
-                          feature_names, feature_types,
-                          enable_categorical=False):
+def dispatch_data_backend(
+    data,
+    missing,
+    threads,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+    enable_categorical: bool = False,
+):
     '''Dispatch data for DMatrix.'''
     if not _is_cudf_ser(data) and not _is_pandas_series(data):
         _check_data_shape(data)
@@ -580,7 +675,9 @@ def dispatch_data_backend(data, missing, threads,
     if _is_scipy_csc(data):
         return _from_scipy_csc(data, missing, feature_names, feature_types)
     if _is_scipy_coo(data):
-        return _from_scipy_csr(data.tocsr(), missing, threads, feature_names, feature_types)
+        return _from_scipy_csr(
+            data.tocsr(), missing, threads, feature_names, feature_types
+        )
     if _is_numpy_array(data):
         return _from_numpy_array(data, missing, threads, feature_names,
                                  feature_types)
@@ -612,8 +709,9 @@ def dispatch_data_backend(data, missing, threads,
                                  feature_types)
     if _is_dt_df(data):
         _warn_unused_missing(data, missing)
-        return _from_dt_df(data, missing, threads, feature_names,
-                           feature_types)
+        return _from_dt_df(
+            data, missing, threads, feature_names, feature_types, enable_categorical
+        )
     if _is_modin_df(data):
         return _from_pandas_df(data, enable_categorical, missing, threads,
                                feature_names, feature_types)
@@ -791,7 +889,12 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
         self.it = 0


-def _proxy_transform(data, feature_names, feature_types, enable_categorical):
+def _proxy_transform(
+    data,
+    feature_names: Optional[List[str]],
+    feature_types: Optional[List[str]],
+    enable_categorical: bool,
+):
     if _is_cudf_df(data) or _is_cudf_ser(data):
         return _transform_cudf_df(
             data, feature_names, feature_types, enable_categorical
@@ -174,8 +174,7 @@ __model_doc = f'''

         .. versionadded:: 1.5.0

         Experimental support for categorical data. Do not set to true unless you are
-        interested in development. Only valid when `gpu_hist` and pandas dataframe are
-        used.
+        interested in development. Only valid when `gpu_hist` and dataframe are used.

     kwargs : dict, optional
         Keyword arguments for XGBoost Booster object. Full documentation of
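For the scikit-learn wrapper the same switch is a constructor argument; as the updated docstring says it is experimental and tied to the gpu_hist tree method, so the sketch below only runs on a CUDA-capable machine and uses made-up data:

    import pandas as pd
    import xgboost as xgb

    X = pd.DataFrame({
        "cat": pd.Series(list("abcaab"), dtype="category"),
        "num": [0.5, 1.5, 2.5, 3.5, 4.5, 5.5],
    })
    y = [1.0, 0.0, 1.0, 0.0, 1.0, 0.0]

    # enable_categorical requires the gpu_hist tree method (and a GPU).
    reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True)
    reg.fit(X, y)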
@@ -200,10 +200,10 @@ void LoadFeatureType(std::vector<std::string>const& type_names, std::vector<Feat
       types->emplace_back(FeatureType::kNumerical);
     } else if (elem == "q") {
       types->emplace_back(FeatureType::kNumerical);
-    } else if (elem == "categorical") {
+    } else if (elem == "c") {
       types->emplace_back(FeatureType::kCategorical);
     } else {
-      LOG(FATAL) << "All feature_types must be one of {int, float, i, q, categorical}.";
+      LOG(FATAL) << "All feature_types must be one of {int, float, i, q, c}.";
     }
   }
 }
@@ -285,7 +285,7 @@ void TestCategoricalTreeDump(std::string format, std::string sep) {
   pos = str.find(cond_str, pos + 1);
   ASSERT_NE(pos, std::string::npos);

-  fmap.PushBack(0, "feat_0", "categorical");
+  fmap.PushBack(0, "feat_0", "c");
   fmap.PushBack(1, "feat_1", "q");
   fmap.PushBack(2, "feat_2", "int");

@@ -172,7 +172,7 @@ Arrow specification.'''
         _test_cudf_metainfo(xgb.DeviceQuantileDMatrix)

     @pytest.mark.skipif(**tm.no_cudf())
-    def test_categorical(self):
+    def test_cudf_categorical(self):
         import cudf
         _X, _y = tm.make_categorical(100, 30, 17, False)
         X = cudf.from_pandas(_X)
@@ -180,11 +180,11 @@ Arrow specification.'''

         Xy = xgb.DMatrix(X, y, enable_categorical=True)
         assert len(Xy.feature_types) == X.shape[1]
-        assert all(t == "categorical" for t in Xy.feature_types)
+        assert all(t == "c" for t in Xy.feature_types)

         Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
         assert len(Xy.feature_types) == X.shape[1]
-        assert all(t == "categorical" for t in Xy.feature_types)
+        assert all(t == "c" for t in Xy.feature_types)


     @pytest.mark.skipif(**tm.no_cudf())
@@ -169,6 +169,19 @@ Arrow specification.'''
         X = cp.random.random((n, 2))
         xgb.DMatrix(X.toDlpack())

+    @pytest.mark.skipif(**tm.no_cupy())
+    def test_cupy_categorical(self):
+        import cupy as cp
+        n_features = 10
+        X, y = tm.make_categorical(10, n_features, n_categories=4, onehot=False)
+        X = cp.asarray(X.values.astype(cp.float32))
+        y = cp.array(y)
+        feature_types = ['c'] * n_features
+
+        assert isinstance(X, cp.ndarray)
+        Xy = xgb.DMatrix(X, y, feature_types=feature_types)
+        np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
+
     @pytest.mark.skipif(**tm.no_cupy())
     def test_dlpack_device_dmat(self):
         import cupy as cp
@@ -339,3 +339,44 @@ class TestDMatrix:
         Xy = xgb.DMatrix(X, y)
         assert Xy.num_row() == 10
         assert Xy.num_col() == 10
+
+    @pytest.mark.skipif(**tm.no_pandas())
+    def test_np_categorical(self):
+        n_features = 10
+        X, y = tm.make_categorical(10, n_features, n_categories=4, onehot=False)
+        X = X.values.astype(np.float32)
+        feature_types = ['c'] * n_features
+
+        assert isinstance(X, np.ndarray)
+        Xy = xgb.DMatrix(X, y, feature_types=feature_types)
+        np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
+
+    def test_scipy_categorical(self):
+        from scipy import sparse
+        n_features = 10
+        X, y = tm.make_categorical(10, n_features, n_categories=4, onehot=False)
+        X = X.values.astype(np.float32)
+        feature_types = ['c'] * n_features
+
+        X[1, 3] = np.NAN
+        X[2, 4] = np.NAN
+        X = sparse.csr_matrix(X)
+
+        Xy = xgb.DMatrix(X, y, feature_types=feature_types)
+        np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
+
+        X = sparse.csc_matrix(X)
+
+        Xy = xgb.DMatrix(X, y, feature_types=feature_types)
+        np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
+
+        X = sparse.coo_matrix(X)
+
+        Xy = xgb.DMatrix(X, y, feature_types=feature_types)
+        np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
+
+    def test_uri_categorical(self):
+        path = os.path.join(dpath, 'agaricus.txt.train')
+        feature_types = ["q"] * 5 + ["c"] + ["q"] * 120
+        Xy = xgb.DMatrix(path + "?indexing_mode=1", feature_types=feature_types)
+        np.testing.assert_equal(np.array(Xy.feature_types), np.array(feature_types))
@@ -128,7 +128,7 @@ class TestPandas:
         X = pd.DataFrame({'f0': X})
         y = rng.randn(rows)
         m = xgb.DMatrix(X, y, enable_categorical=True)
-        assert m.feature_types[0] == 'categorical'
+        assert m.feature_types[0] == 'c'

     def test_pandas_sparse(self):
         import pandas as pd