Categorical data support for cuDF. (#7042)
* Add support in DMatrix.
* Add support in DQM, except for iterator.
parent 5c2d7a18c9
commit d9799b09d0
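For context, a minimal usage sketch of the feature this commit adds, mirroring the test_categorical case in the diff below. It assumes a CUDA-capable environment with cuDF and this build of XGBoost; the toy data here is hypothetical and not part of the commit.

import cudf
import numpy as np
import xgboost as xgb

rng = np.random.RandomState(1994)
# Hypothetical toy data: one cuDF categorical feature column and a numeric label.
X = cudf.DataFrame({"c0": rng.randint(0, 5, size=100)})
X["c0"] = X["c0"].astype("category")
y = cudf.Series(rng.rand(100))

# enable_categorical=True maps cuDF categorical columns to the "categorical"
# feature type instead of rejecting them.
Xy = xgb.DMatrix(X, y, enable_categorical=True)
assert Xy.feature_types == ["categorical"]

# DeviceQuantileDMatrix accepts the same flag after this commit,
# except when the input is a data iterator.
Xy_dqm = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)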
@@ -231,17 +231,6 @@ def _numpy2ctypes_type(dtype):
     return _NUMPY_TO_CTYPES_MAPPING[dtype]


-def _array_interface(data: np.ndarray) -> bytes:
-    assert (
-        data.dtype.hasobject is False
-    ), "Input data contains `object` dtype. Expecting numeric data."
-    interface = data.__array_interface__
-    if "mask" in interface:
-        interface["mask"] = interface["mask"].__array_interface__
-    interface_str = bytes(json.dumps(interface), "utf-8")
-    return interface_str
-
-
 def _cuda_array_interface(data) -> bytes:
     assert (
         data.dtype.hasobject is False
@@ -353,11 +342,17 @@ class DataIter:
         if self.exception is not None:
             return 0

-        def data_handle(data, feature_names=None, feature_types=None, **kwargs):
+        def data_handle(
+            data,
+            feature_names=None,
+            feature_types=None,
+            enable_categorical=False,
+            **kwargs
+        ):
             from .data import dispatch_device_quantile_dmatrix_set_data
             from .data import _device_quantile_transform
             data, feature_names, feature_types = _device_quantile_transform(
-                data, feature_names, feature_types
+                data, feature_names, feature_types, enable_categorical,
             )
             dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
             self.proxy.set_info(
@@ -1023,7 +1018,7 @@ class _ProxyDMatrix(DMatrix):
     def _set_data_from_cuda_columnar(self, data):
         '''Set data from CUDA columnar format.'''
         from .data import _cudf_array_interfaces
-        interfaces_str = _cudf_array_interfaces(data)
+        _, interfaces_str = _cudf_array_interfaces(data)
         _check_call(
             _LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
                 self.handle,
@@ -1076,10 +1071,6 @@ class DeviceQuantileDMatrix(DMatrix):
             self.handle = data
             return

-        if enable_categorical:
-            raise NotImplementedError(
-                'categorical support is not enabled on DeviceQuantileDMatrix.'
-            )
         if qid is not None and group is not None:
             raise ValueError(
                 'Only one of the eval_qid or eval_group for each evaluation '
@@ -1098,9 +1089,10 @@ class DeviceQuantileDMatrix(DMatrix):
             feature_weights=feature_weights,
             feature_names=feature_names,
             feature_types=feature_types,
+            enable_categorical=enable_categorical,
         )

-    def _init(self, data, feature_names, feature_types, **meta):
+    def _init(self, data, enable_categorical, **meta):
         from .data import (
             _is_dlpack,
             _transform_dlpack,
@@ -1114,9 +1106,13 @@ class DeviceQuantileDMatrix(DMatrix):
             data = _transform_dlpack(data)
         if _is_iter(data):
             it = data
+            if enable_categorical:
+                raise NotImplementedError(
+                    "categorical support is not enabled on data iterator."
+                )
         else:
             it = SingleBatchInternalIter(
-                data, **meta, feature_names=feature_names, feature_types=feature_types
+                data=data, enable_categorical=enable_categorical, **meta
             )

         reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
@@ -1920,6 +1916,7 @@ class Booster(object):
                     f"got {data.shape[1]}"
                 )

+        from .data import _array_interface
         if isinstance(data, np.ndarray):
             from .data import _ensure_np_dtype
             data, _ = _ensure_np_dtype(data, data.dtype)
@@ -1974,7 +1971,7 @@ class Booster(object):
         if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
             from .data import _cudf_array_interfaces

-            interfaces_str = _cudf_array_interfaces(data)
+            _, interfaces_str = _cudf_array_interfaces(data)
             _check_call(
                 _LIB.XGBoosterPredictFromCudaColumnar(
                     self.handle,
@@ -5,12 +5,12 @@ import ctypes
 import json
 import warnings
 import os
-from typing import Any
+from typing import Any, Tuple

 import numpy as np

 from .core import c_array, _LIB, _check_call, c_str
-from .core import _array_interface, _cuda_array_interface
+from .core import _cuda_array_interface
 from .core import DataIter, _ProxyDMatrix, DMatrix
 from .compat import lazy_isinstance

@@ -41,6 +41,17 @@ def _is_scipy_csr(data):
     return isinstance(data, scipy.sparse.csr_matrix)


+def _array_interface(data: np.ndarray) -> bytes:
+    assert (
+        data.dtype.hasobject is False
+    ), "Input data contains `object` dtype. Expecting numeric data."
+    interface = data.__array_interface__
+    if "mask" in interface:
+        interface["mask"] = interface["mask"].__array_interface__
+    interface_str = bytes(json.dumps(interface), "utf-8")
+    return interface_str
+
+
 def _from_scipy_csr(data, missing, nthread, feature_names, feature_types):
     """Initialize data from a CSR matrix."""
     if len(data.indices) != len(data.data):
@@ -179,7 +190,7 @@ _pandas_dtype_mapper = {
     'float16': 'float',
     'float32': 'float',
     'float64': 'float',
-    'bool': 'i'
+    'bool': 'i',
 }


@@ -349,54 +360,73 @@ def _is_cudf_df(data):
     return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame)


-def _cudf_array_interfaces(data):
-    '''Extract CuDF __cuda_array_interface__'''
+def _cudf_array_interfaces(data) -> Tuple[list, list]:
+    """Extract CuDF __cuda_array_interface__. This is special as it returns a new list of
+    data and a list of array interfaces. The data is list of categorical codes that
+    caller can safely ignore, but have to keep their reference alive until usage of array
+    interface is finished.
+
+    """
+    from cudf.utils.dtypes import is_categorical_dtype
+    cat_codes = []
     interfaces = []
     if _is_cudf_ser(data):
         interfaces.append(data.__cuda_array_interface__)
     else:
         for col in data:
-            interface = data[col].__cuda_array_interface__
-            if 'mask' in interface:
-                interface['mask'] = interface['mask'].__cuda_array_interface__
+            if is_categorical_dtype(data[col].dtype):
+                codes = data[col].cat.codes
+                interface = codes.__cuda_array_interface__
+                cat_codes.append(codes)
+            else:
+                interface = data[col].__cuda_array_interface__
+            if "mask" in interface:
+                interface["mask"] = interface["mask"].__cuda_array_interface__
             interfaces.append(interface)
-    interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
-    return interfaces_str
+    interfaces_str = bytes(json.dumps(interfaces, indent=2), "utf-8")
+    return cat_codes, interfaces_str


-def _transform_cudf_df(data, feature_names, feature_types):
+def _transform_cudf_df(data, feature_names, feature_types, enable_categorical):
+    from cudf.utils.dtypes import is_categorical_dtype
+
     if feature_names is None:
         if _is_cudf_ser(data):
             feature_names = [data.name]
-        elif lazy_isinstance(
-                data.columns, 'cudf.core.multiindex', 'MultiIndex'):
-            feature_names = [
-                ' '.join([str(x) for x in i])
-                for i in data.columns
-            ]
+        elif lazy_isinstance(data.columns, "cudf.core.multiindex", "MultiIndex"):
+            feature_names = [" ".join([str(x) for x in i]) for i in data.columns]
         else:
             feature_names = data.columns.format()
     if feature_types is None:
+        feature_types = []
         if _is_cudf_ser(data):
             dtypes = [data.dtype]
         else:
             dtypes = data.dtypes
-        feature_types = [_pandas_dtype_mapper[d.name]
-                         for d in dtypes]
+        for dtype in dtypes:
+            if is_categorical_dtype(dtype) and enable_categorical:
+                feature_types.append("categorical")
+            else:
+                feature_types.append(_pandas_dtype_mapper[dtype.name])
     return data, feature_names, feature_types


-def _from_cudf_df(data, missing, nthread, feature_names, feature_types):
+def _from_cudf_df(
+    data, missing, nthread, feature_names, feature_types, enable_categorical
+):
     data, feature_names, feature_types = _transform_cudf_df(
-        data, feature_names, feature_types)
-    interfaces_str = _cudf_array_interfaces(data)
+        data, feature_names, feature_types, enable_categorical
+    )
+    _, interfaces_str = _cudf_array_interfaces(data)
     handle = ctypes.c_void_p()
     _check_call(
         _LIB.XGDMatrixCreateFromArrayInterfaceColumns(
             interfaces_str,
             ctypes.c_float(missing),
             ctypes.c_int(nthread),
-            ctypes.byref(handle)))
+            ctypes.byref(handle),
+        )
+    )
     return handle, feature_names, feature_types

@@ -554,12 +584,10 @@ def dispatch_data_backend(data, missing, threads,
     if _is_pandas_series(data):
         return _from_pandas_series(data, missing, threads, feature_names,
                                    feature_types)
-    if _is_cudf_df(data):
-        return _from_cudf_df(data, missing, threads, feature_names,
-                             feature_types)
-    if _is_cudf_ser(data):
-        return _from_cudf_df(data, missing, threads, feature_names,
-                             feature_types)
+    if _is_cudf_df(data) or _is_cudf_ser(data):
+        return _from_cudf_df(
+            data, missing, threads, feature_names, feature_types, enable_categorical
+        )
     if _is_cupy_array(data):
         return _from_cupy_array(data, missing, threads, feature_names,
                                 feature_types)
@@ -731,30 +759,8 @@ class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
     area for meta info.

     '''
-    def __init__(
-        self, data,
-        label,
-        weight,
-        base_margin,
-        group,
-        qid,
-        label_lower_bound,
-        label_upper_bound,
-        feature_weights,
-        feature_names,
-        feature_types
-    ):
-        self.data = data
-        self.label = label
-        self.weight = weight
-        self.base_margin = base_margin
-        self.group = group
-        self.qid = qid
-        self.label_lower_bound = label_lower_bound
-        self.label_upper_bound = label_upper_bound
-        self.feature_weights = feature_weights
-        self.feature_names = feature_names
-        self.feature_types = feature_types
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
         self.it = 0  # pylint: disable=invalid-name
         super().__init__()

@@ -762,33 +768,24 @@ class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
         if self.it == 1:
             return 0
         self.it += 1
-        input_data(data=self.data, label=self.label,
-                   weight=self.weight, base_margin=self.base_margin,
-                   group=self.group,
-                   qid=self.qid,
-                   label_lower_bound=self.label_lower_bound,
-                   label_upper_bound=self.label_upper_bound,
-                   feature_weights=self.feature_weights,
-                   feature_names=self.feature_names,
-                   feature_types=self.feature_types)
+        input_data(**self.kwargs)
         return 1

     def reset(self):
         self.it = 0


-def _device_quantile_transform(data, feature_names, feature_types):
-    if _is_cudf_df(data):
-        return _transform_cudf_df(data, feature_names, feature_types)
-    if _is_cudf_ser(data):
-        return _transform_cudf_df(data, feature_names, feature_types)
+def _device_quantile_transform(data, feature_names, feature_types, enable_categorical):
+    if _is_cudf_df(data) or _is_cudf_ser(data):
+        return _transform_cudf_df(
+            data, feature_names, feature_types, enable_categorical
+        )
     if _is_cupy_array(data):
         data = _transform_cupy_array(data)
         return data, feature_names, feature_types
     if _is_dlpack(data):
         return _transform_dlpack(data), feature_names, feature_types
-    raise TypeError('Value type is not supported for data iterator:' +
-                    str(type(data)))
+    raise TypeError("Value type is not supported for data iterator:" + str(type(data)))


 def dispatch_device_quantile_dmatrix_set_data(proxy: _ProxyDMatrix, data: Any) -> None:
@@ -171,6 +171,21 @@ Arrow specification.'''
     def test_cudf_metainfo_device_dmatrix(self):
         _test_cudf_metainfo(xgb.DeviceQuantileDMatrix)

+    @pytest.mark.skipif(**tm.no_cudf())
+    def test_categorical(self):
+        import cudf
+        _X, _y = tm.make_categorical(100, 30, 17, False)
+        X = cudf.from_pandas(_X)
+        y = cudf.from_pandas(_y)
+
+        Xy = xgb.DMatrix(X, y, enable_categorical=True)
+        assert len(Xy.feature_types) == X.shape[1]
+        assert all(t == "categorical" for t in Xy.feature_types)
+
+        Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
+        assert len(Xy.feature_types) == X.shape[1]
+        assert all(t == "categorical" for t in Xy.feature_types)
+
     @pytest.mark.skipif(**tm.no_cudf())
     @pytest.mark.skipif(**tm.no_cupy())
@@ -43,22 +43,8 @@ class TestGPUUpdaters:
         assert tm.non_increasing(result['train'][dataset.metric])

     def run_categorical_basic(self, rows, cols, rounds, cats):
-        import pandas as pd
-        rng = np.random.RandomState(1994)
-
-        pd_dict = {}
-        for i in range(cols):
-            c = rng.randint(low=0, high=cats+1, size=rows)
-            pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
-
-        df = pd.DataFrame(pd_dict)
-        label = df.iloc[:, 0]
-        for i in range(0, cols-1):
-            label += df.iloc[:, i]
-        label += 1
-        df = df.astype('category')
-        onehot = pd.get_dummies(df)
-        cat = df
+        onehot, label = tm.make_categorical(rows, cols, cats, True)
+        cat, _ = tm.make_categorical(rows, cols, cats, False)

         by_etl_results = {}
         by_builtin_results = {}
@@ -234,6 +234,34 @@ def get_mq2008(dpath):
             x_valid, y_valid, qid_valid)


+@memory.cache
+def make_categorical(
+    n_samples: int, n_features: int, n_categories: int, onehot_enc: bool
+):
+    import pandas as pd
+
+    rng = np.random.RandomState(1994)
+
+    pd_dict = {}
+    for i in range(n_features + 1):
+        c = rng.randint(low=0, high=n_categories + 1, size=n_samples)
+        pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
+
+    df = pd.DataFrame(pd_dict)
+    label = df.iloc[:, 0]
+    df = df.iloc[:, 1:]
+    for i in range(0, n_features):
+        label += df.iloc[:, i]
+    label += 1
+
+    df = df.astype("category")
+    if onehot_enc:
+        cat = pd.get_dummies(df)
+    else:
+        cat = df
+    return cat, label
+
+
 _unweighted_datasets_strategy = strategies.sampled_from(
     [TestDataset('boston', get_boston, 'reg:squarederror', 'rmse'),
      TestDataset('digits', get_digits, 'multi:softmax', 'mlogloss'),
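As a reference for the contract documented in the new _cudf_array_interfaces docstring above (the returned categorical codes must be kept alive while the JSON array-interface bytes are in use), here is a rough sketch. _cudf_array_interfaces is a private helper, so this is illustrative only; it assumes cuDF and this build of XGBoost are installed.

import cudf
from xgboost.data import _cudf_array_interfaces

df = cudf.DataFrame({"a": [1.0, 2.0, 3.0], "b": ["x", "y", "x"]})
df["b"] = df["b"].astype("category")

cat_codes, interfaces_str = _cudf_array_interfaces(df)
# One codes column is retained for the single categorical column "b";
# keep cat_codes referenced for as long as interfaces_str is in use.
assert len(cat_codes) == 1
# interfaces_str is UTF-8 encoded JSON: one __cuda_array_interface__ per column.
print(interfaces_str.decode("utf-8"))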