Categorical data support for cuDF. (#7042)

* Add support in DMatrix.
* Add support in DQM, except for iterator.
This commit is contained in:
Jiaming Yuan 2021-06-17 13:54:33 +08:00 committed by GitHub
parent 5c2d7a18c9
commit d9799b09d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 129 additions and 106 deletions

View File

@ -231,17 +231,6 @@ def _numpy2ctypes_type(dtype):
return _NUMPY_TO_CTYPES_MAPPING[dtype]
def _array_interface(data: np.ndarray) -> bytes:
assert (
data.dtype.hasobject is False
), "Input data contains `object` dtype. Expecting numeric data."
interface = data.__array_interface__
if "mask" in interface:
interface["mask"] = interface["mask"].__array_interface__
interface_str = bytes(json.dumps(interface), "utf-8")
return interface_str
def _cuda_array_interface(data) -> bytes:
assert (
data.dtype.hasobject is False
@ -353,11 +342,17 @@ class DataIter:
if self.exception is not None:
return 0
def data_handle(data, feature_names=None, feature_types=None, **kwargs):
def data_handle(
data,
feature_names=None,
feature_types=None,
enable_categorical=False,
**kwargs
):
from .data import dispatch_device_quantile_dmatrix_set_data
from .data import _device_quantile_transform
data, feature_names, feature_types = _device_quantile_transform(
data, feature_names, feature_types
data, feature_names, feature_types, enable_categorical,
)
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
self.proxy.set_info(
@ -1023,7 +1018,7 @@ class _ProxyDMatrix(DMatrix):
def _set_data_from_cuda_columnar(self, data):
'''Set data from CUDA columnar format.1'''
from .data import _cudf_array_interfaces
interfaces_str = _cudf_array_interfaces(data)
_, interfaces_str = _cudf_array_interfaces(data)
_check_call(
_LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
self.handle,
@ -1076,10 +1071,6 @@ class DeviceQuantileDMatrix(DMatrix):
self.handle = data
return
if enable_categorical:
raise NotImplementedError(
'categorical support is not enabled on DeviceQuantileDMatrix.'
)
if qid is not None and group is not None:
raise ValueError(
'Only one of the eval_qid or eval_group for each evaluation '
@ -1098,9 +1089,10 @@ class DeviceQuantileDMatrix(DMatrix):
feature_weights=feature_weights,
feature_names=feature_names,
feature_types=feature_types,
enable_categorical=enable_categorical,
)
def _init(self, data, feature_names, feature_types, **meta):
def _init(self, data, enable_categorical, **meta):
from .data import (
_is_dlpack,
_transform_dlpack,
@ -1114,9 +1106,13 @@ class DeviceQuantileDMatrix(DMatrix):
data = _transform_dlpack(data)
if _is_iter(data):
it = data
if enable_categorical:
raise NotImplementedError(
"categorical support is not enabled on data iterator."
)
else:
it = SingleBatchInternalIter(
data, **meta, feature_names=feature_names, feature_types=feature_types
data=data, enable_categorical=enable_categorical, **meta
)
reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
@ -1920,6 +1916,7 @@ class Booster(object):
f"got {data.shape[1]}"
)
from .data import _array_interface
if isinstance(data, np.ndarray):
from .data import _ensure_np_dtype
data, _ = _ensure_np_dtype(data, data.dtype)
@ -1974,7 +1971,7 @@ class Booster(object):
if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
from .data import _cudf_array_interfaces
interfaces_str = _cudf_array_interfaces(data)
_, interfaces_str = _cudf_array_interfaces(data)
_check_call(
_LIB.XGBoosterPredictFromCudaColumnar(
self.handle,

View File

@ -5,12 +5,12 @@ import ctypes
import json
import warnings
import os
from typing import Any
from typing import Any, Tuple
import numpy as np
from .core import c_array, _LIB, _check_call, c_str
from .core import _array_interface, _cuda_array_interface
from .core import _cuda_array_interface
from .core import DataIter, _ProxyDMatrix, DMatrix
from .compat import lazy_isinstance
@ -41,6 +41,17 @@ def _is_scipy_csr(data):
return isinstance(data, scipy.sparse.csr_matrix)
def _array_interface(data: np.ndarray) -> bytes:
assert (
data.dtype.hasobject is False
), "Input data contains `object` dtype. Expecting numeric data."
interface = data.__array_interface__
if "mask" in interface:
interface["mask"] = interface["mask"].__array_interface__
interface_str = bytes(json.dumps(interface), "utf-8")
return interface_str
def _from_scipy_csr(data, missing, nthread, feature_names, feature_types):
"""Initialize data from a CSR matrix."""
if len(data.indices) != len(data.data):
@ -179,7 +190,7 @@ _pandas_dtype_mapper = {
'float16': 'float',
'float32': 'float',
'float64': 'float',
'bool': 'i'
'bool': 'i',
}
@ -349,54 +360,73 @@ def _is_cudf_df(data):
return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame)
def _cudf_array_interfaces(data):
'''Extract CuDF __cuda_array_interface__'''
def _cudf_array_interfaces(data) -> Tuple[list, list]:
"""Extract CuDF __cuda_array_interface__. This is special as it returns a new list of
data and a list of array interfaces. The data is list of categorical codes that
caller can safely ignore, but have to keep their reference alive until usage of array
interface is finished.
"""
from cudf.utils.dtypes import is_categorical_dtype
cat_codes = []
interfaces = []
if _is_cudf_ser(data):
interfaces.append(data.__cuda_array_interface__)
else:
for col in data:
interface = data[col].__cuda_array_interface__
if 'mask' in interface:
interface['mask'] = interface['mask'].__cuda_array_interface__
if is_categorical_dtype(data[col].dtype):
codes = data[col].cat.codes
interface = codes.__cuda_array_interface__
cat_codes.append(codes)
else:
interface = data[col].__cuda_array_interface__
if "mask" in interface:
interface["mask"] = interface["mask"].__cuda_array_interface__
interfaces.append(interface)
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
return interfaces_str
interfaces_str = bytes(json.dumps(interfaces, indent=2), "utf-8")
return cat_codes, interfaces_str
def _transform_cudf_df(data, feature_names, feature_types):
def _transform_cudf_df(data, feature_names, feature_types, enable_categorical):
from cudf.utils.dtypes import is_categorical_dtype
if feature_names is None:
if _is_cudf_ser(data):
feature_names = [data.name]
elif lazy_isinstance(
data.columns, 'cudf.core.multiindex', 'MultiIndex'):
feature_names = [
' '.join([str(x) for x in i])
for i in data.columns
]
elif lazy_isinstance(data.columns, "cudf.core.multiindex", "MultiIndex"):
feature_names = [" ".join([str(x) for x in i]) for i in data.columns]
else:
feature_names = data.columns.format()
if feature_types is None:
feature_types = []
if _is_cudf_ser(data):
dtypes = [data.dtype]
else:
dtypes = data.dtypes
feature_types = [_pandas_dtype_mapper[d.name]
for d in dtypes]
for dtype in dtypes:
if is_categorical_dtype(dtype) and enable_categorical:
feature_types.append("categorical")
else:
feature_types.append(_pandas_dtype_mapper[dtype.name])
return data, feature_names, feature_types
def _from_cudf_df(data, missing, nthread, feature_names, feature_types):
def _from_cudf_df(
data, missing, nthread, feature_names, feature_types, enable_categorical
):
data, feature_names, feature_types = _transform_cudf_df(
data, feature_names, feature_types)
interfaces_str = _cudf_array_interfaces(data)
data, feature_names, feature_types, enable_categorical
)
_, interfaces_str = _cudf_array_interfaces(data)
handle = ctypes.c_void_p()
_check_call(
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
interfaces_str,
ctypes.c_float(missing),
ctypes.c_int(nthread),
ctypes.byref(handle)))
ctypes.byref(handle),
)
)
return handle, feature_names, feature_types
@ -554,12 +584,10 @@ def dispatch_data_backend(data, missing, threads,
if _is_pandas_series(data):
return _from_pandas_series(data, missing, threads, feature_names,
feature_types)
if _is_cudf_df(data):
return _from_cudf_df(data, missing, threads, feature_names,
feature_types)
if _is_cudf_ser(data):
return _from_cudf_df(data, missing, threads, feature_names,
feature_types)
if _is_cudf_df(data) or _is_cudf_ser(data):
return _from_cudf_df(
data, missing, threads, feature_names, feature_types, enable_categorical
)
if _is_cupy_array(data):
return _from_cupy_array(data, missing, threads, feature_names,
feature_types)
@ -731,30 +759,8 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
area for meta info.
'''
def __init__(
self, data,
label,
weight,
base_margin,
group,
qid,
label_lower_bound,
label_upper_bound,
feature_weights,
feature_names,
feature_types
):
self.data = data
self.label = label
self.weight = weight
self.base_margin = base_margin
self.group = group
self.qid = qid
self.label_lower_bound = label_lower_bound
self.label_upper_bound = label_upper_bound
self.feature_weights = feature_weights
self.feature_names = feature_names
self.feature_types = feature_types
def __init__(self, **kwargs):
self.kwargs = kwargs
self.it = 0 # pylint: disable=invalid-name
super().__init__()
@ -762,33 +768,24 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
if self.it == 1:
return 0
self.it += 1
input_data(data=self.data, label=self.label,
weight=self.weight, base_margin=self.base_margin,
group=self.group,
qid=self.qid,
label_lower_bound=self.label_lower_bound,
label_upper_bound=self.label_upper_bound,
feature_weights=self.feature_weights,
feature_names=self.feature_names,
feature_types=self.feature_types)
input_data(**self.kwargs)
return 1
def reset(self):
self.it = 0
def _device_quantile_transform(data, feature_names, feature_types):
if _is_cudf_df(data):
return _transform_cudf_df(data, feature_names, feature_types)
if _is_cudf_ser(data):
return _transform_cudf_df(data, feature_names, feature_types)
def _device_quantile_transform(data, feature_names, feature_types, enable_categorical):
if _is_cudf_df(data) or _is_cudf_ser(data):
return _transform_cudf_df(
data, feature_names, feature_types, enable_categorical
)
if _is_cupy_array(data):
data = _transform_cupy_array(data)
return data, feature_names, feature_types
if _is_dlpack(data):
return _transform_dlpack(data), feature_names, feature_types
raise TypeError('Value type is not supported for data iterator:' +
str(type(data)))
raise TypeError("Value type is not supported for data iterator:" + str(type(data)))
def dispatch_device_quantile_dmatrix_set_data(proxy: _ProxyDMatrix, data: Any) -> None:

View File

@ -171,6 +171,21 @@ Arrow specification.'''
def test_cudf_metainfo_device_dmatrix(self):
_test_cudf_metainfo(xgb.DeviceQuantileDMatrix)
@pytest.mark.skipif(**tm.no_cudf())
def test_categorical(self):
import cudf
_X, _y = tm.make_categorical(100, 30, 17, False)
X = cudf.from_pandas(_X)
y = cudf.from_pandas(_y)
Xy = xgb.DMatrix(X, y, enable_categorical=True)
assert len(Xy.feature_types) == X.shape[1]
assert all(t == "categorical" for t in Xy.feature_types)
Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
assert len(Xy.feature_types) == X.shape[1]
assert all(t == "categorical" for t in Xy.feature_types)
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.skipif(**tm.no_cupy())

View File

@ -43,22 +43,8 @@ class TestGPUUpdaters:
assert tm.non_increasing(result['train'][dataset.metric])
def run_categorical_basic(self, rows, cols, rounds, cats):
import pandas as pd
rng = np.random.RandomState(1994)
pd_dict = {}
for i in range(cols):
c = rng.randint(low=0, high=cats+1, size=rows)
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
df = pd.DataFrame(pd_dict)
label = df.iloc[:, 0]
for i in range(0, cols-1):
label += df.iloc[:, i]
label += 1
df = df.astype('category')
onehot = pd.get_dummies(df)
cat = df
onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)
by_etl_results = {}
by_builtin_results = {}

View File

@ -234,6 +234,34 @@ def get_mq2008(dpath):
x_valid, y_valid, qid_valid)
@memory.cache
def make_categorical(
n_samples: int, n_features: int, n_categories: int, onehot_enc: bool
):
import pandas as pd
rng = np.random.RandomState(1994)
pd_dict = {}
for i in range(n_features + 1):
c = rng.randint(low=0, high=n_categories + 1, size=n_samples)
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
df = pd.DataFrame(pd_dict)
label = df.iloc[:, 0]
df = df.iloc[:, 1:]
for i in range(0, n_features):
label += df.iloc[:, i]
label += 1
df = df.astype("category")
if onehot_enc:
cat = pd.get_dummies(df)
else:
cat = df
return cat, label
_unweighted_datasets_strategy = strategies.sampled_from(
[TestDataset('boston', get_boston, 'reg:squarederror', 'rmse'),
TestDataset('digits', get_digits, 'multi:softmax', 'mlogloss'),