Fix mixed types with cuDF. (#8280)

This commit is contained in:
Jiaming Yuan 2022-09-29 00:57:52 +08:00 committed by GitHub
parent f835368bcf
commit 6925b222e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 106 additions and 47 deletions

View File

@ -1,49 +1,72 @@
# pylint: disable=too-many-arguments, too-many-branches, invalid-name # pylint: disable=too-many-arguments, too-many-branches, invalid-name
# pylint: disable=too-many-lines, too-many-locals # pylint: disable=too-many-lines, too-many-locals
"""Core XGBoost Library.""" """Core XGBoost Library."""
from abc import ABC, abstractmethod
from collections.abc import Mapping
import copy import copy
from typing import List, Optional, Any, Union, Dict, TypeVar
from typing import Callable, Tuple, cast, Sequence, Type, Iterable
import ctypes import ctypes
import json
import os import os
import re import re
import sys import sys
import json
import warnings import warnings
from abc import ABC, abstractmethod
from collections.abc import Mapping
from functools import wraps from functools import wraps
from inspect import signature, Parameter from inspect import Parameter, signature
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
from .compat import DataFrame, py_str, PANDAS_INSTALLED
from .libpath import find_lib_path
from ._typing import ( from ._typing import (
CStrPptr, _T,
c_bst_ulong, ArrayLike,
BoosterParam,
CFloatPtr,
CNumeric, CNumeric,
DataType,
CNumericPtr, CNumericPtr,
CStrPptr,
CStrPtr, CStrPtr,
CTypeT, CTypeT,
ArrayLike,
CFloatPtr,
NumpyOrCupy,
FeatureInfo,
FeatureTypes,
FeatureNames,
_T,
CupyT, CupyT,
BoosterParam DataType,
FeatureInfo,
FeatureNames,
FeatureTypes,
NumpyOrCupy,
c_bst_ulong,
) )
from .compat import PANDAS_INSTALLED, DataFrame, py_str
from .libpath import find_lib_path
class XGBoostError(ValueError): class XGBoostError(ValueError):
"""Error thrown by xgboost trainer.""" """Error thrown by xgboost trainer."""
@overload
def from_pystr_to_cstr(data: str) -> bytes:
...
@overload
def from_pystr_to_cstr(data: List[str]) -> ctypes.Array:
...
def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, ctypes.Array]: def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, ctypes.Array]:
"""Convert a Python str or list of Python str to C pointer """Convert a Python str or list of Python str to C pointer

View File

@ -3,24 +3,33 @@
'''Data dispatching for DMatrix.''' '''Data dispatching for DMatrix.'''
import ctypes import ctypes
import json import json
import warnings
import os import os
from typing import Any, Tuple, Callable, Optional, List, Union, Iterator, Sequence, cast import warnings
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union, cast
import numpy as np import numpy as np
from .core import c_array, _LIB, _check_call, c_str
from .core import _cuda_array_interface
from .core import DataIter, _ProxyDMatrix, DMatrix
from .compat import lazy_isinstance, DataFrame
from ._typing import ( from ._typing import (
c_bst_ulong,
DataType,
FeatureTypes,
FeatureNames,
NumpyDType,
CupyT, CupyT,
FloatCompatible, PandasDType DataType,
FeatureNames,
FeatureTypes,
FloatCompatible,
NumpyDType,
PandasDType,
c_bst_ulong,
)
from .compat import DataFrame, lazy_isinstance
from .core import (
_LIB,
DataIter,
DMatrix,
_check_call,
_cuda_array_interface,
_ProxyDMatrix,
c_array,
c_str,
from_pystr_to_cstr,
) )
DispatchedDataBackendReturnType = Tuple[ DispatchedDataBackendReturnType = Tuple[
@ -631,10 +640,10 @@ def _is_cudf_df(data: DataType) -> bool:
def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes: def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes:
"""Extract CuDF __cuda_array_interface__. This is special as it returns a new list of """Extract CuDF __cuda_array_interface__. This is special as it returns a new list
data and a list of array interfaces. The data is list of categorical codes that of data and a list of array interfaces. The data is list of categorical codes that
caller can safely ignore, but have to keep their reference alive until usage of array caller can safely ignore, but have to keep their reference alive until usage of
interface is finished. array interface is finished.
""" """
try: try:
@ -643,14 +652,18 @@ def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes:
from cudf.utils.dtypes import is_categorical_dtype from cudf.utils.dtypes import is_categorical_dtype
interfaces = [] interfaces = []
def append(interface: dict) -> None:
if "mask" in interface:
interface["mask"] = interface["mask"].__cuda_array_interface__
interfaces.append(interface)
if _is_cudf_ser(data): if _is_cudf_ser(data):
if is_categorical_dtype(data.dtype): if is_categorical_dtype(data.dtype):
interface = cat_codes[0].__cuda_array_interface__ interface = cat_codes[0].__cuda_array_interface__
else: else:
interface = data.__cuda_array_interface__ interface = data.__cuda_array_interface__
if "mask" in interface: append(interface)
interface["mask"] = interface["mask"].__cuda_array_interface__
interfaces.append(interface)
else: else:
for i, col in enumerate(data): for i, col in enumerate(data):
if is_categorical_dtype(data[col].dtype): if is_categorical_dtype(data[col].dtype):
@ -658,10 +671,8 @@ def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes:
interface = codes.__cuda_array_interface__ interface = codes.__cuda_array_interface__
else: else:
interface = data[col].__cuda_array_interface__ interface = data[col].__cuda_array_interface__
if "mask" in interface: append(interface)
interface["mask"] = interface["mask"].__cuda_array_interface__ interfaces_str = from_pystr_to_cstr(json.dumps(interfaces))
interfaces.append(interface)
interfaces_str = bytes(json.dumps(interfaces, indent=2), "utf-8")
return interfaces_str return interfaces_str
@ -722,9 +733,14 @@ def _transform_cudf_df(
cat_codes.append(codes) cat_codes.append(codes)
else: else:
for col in data: for col in data:
if is_categorical_dtype(data[col].dtype) and enable_categorical: dtype = data[col].dtype
if is_categorical_dtype(dtype) and enable_categorical:
codes = data[col].cat.codes codes = data[col].cat.codes
cat_codes.append(codes) cat_codes.append(codes)
elif is_categorical_dtype(dtype):
raise ValueError(_ENABLE_CAT_ERR)
else:
cat_codes.append([])
return data, cat_codes, feature_names, feature_types return data, cat_codes, feature_names, feature_types

View File

@ -1,6 +1,8 @@
import json
import sys
import numpy as np import numpy as np
import xgboost as xgb import xgboost as xgb
import sys
import pytest import pytest
sys.path.append("tests/python") sys.path.append("tests/python")
@ -176,20 +178,38 @@ Arrow specification.'''
_test_cudf_metainfo(xgb.DeviceQuantileDMatrix) _test_cudf_metainfo(xgb.DeviceQuantileDMatrix)
@pytest.mark.skipif(**tm.no_cudf()) @pytest.mark.skipif(**tm.no_cudf())
def test_cudf_categorical(self): def test_cudf_categorical(self) -> None:
import cudf import cudf
_X, _y = tm.make_categorical(100, 30, 17, False) n_features = 30
_X, _y = tm.make_categorical(100, n_features, 17, False)
X = cudf.from_pandas(_X) X = cudf.from_pandas(_X)
y = cudf.from_pandas(_y) y = cudf.from_pandas(_y)
Xy = xgb.DMatrix(X, y, enable_categorical=True) Xy = xgb.DMatrix(X, y, enable_categorical=True)
assert Xy.feature_types is not None
assert len(Xy.feature_types) == X.shape[1] assert len(Xy.feature_types) == X.shape[1]
assert all(t == "c" for t in Xy.feature_types) assert all(t == "c" for t in Xy.feature_types)
Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True) Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
assert Xy.feature_types is not None
assert len(Xy.feature_types) == X.shape[1] assert len(Xy.feature_types) == X.shape[1]
assert all(t == "c" for t in Xy.feature_types) assert all(t == "c" for t in Xy.feature_types)
# mixed dtypes
X["1"] = X["1"].astype(np.int64)
X["3"] = X["3"].astype(np.int64)
df, cat_codes, _, _ = xgb.data._transform_cudf_df(
X, None, None, enable_categorical=True
)
assert X.shape[1] == n_features
assert len(cat_codes) == X.shape[1]
assert not cat_codes[0]
assert not cat_codes[2]
interfaces_str = xgb.data._cudf_array_interfaces(df, cat_codes)
interfaces = json.loads(interfaces_str)
assert len(interfaces) == X.shape[1]
# test missing value # test missing value
X = cudf.DataFrame({"f0": ["a", "b", np.NaN]}) X = cudf.DataFrame({"f0": ["a", "b", np.NaN]})
X["f0"] = X["f0"].astype("category") X["f0"] = X["f0"].astype("category")
@ -206,7 +226,7 @@ Arrow specification.'''
assert Xy.num_row() == 3 assert Xy.num_row() == 3
assert Xy.num_col() == 1 assert Xy.num_col() == 1
with pytest.raises(ValueError): with pytest.raises(ValueError, match="enable_categorical"):
xgb.DeviceQuantileDMatrix(X, y) xgb.DeviceQuantileDMatrix(X, y)
Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True) Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)