Fix mixed types with cuDF. (#8280)
This commit is contained in:
parent
f835368bcf
commit
6925b222e0
@ -1,49 +1,72 @@
|
|||||||
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
||||||
# pylint: disable=too-many-lines, too-many-locals
|
# pylint: disable=too-many-lines, too-many-locals
|
||||||
"""Core XGBoost Library."""
|
"""Core XGBoost Library."""
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from collections.abc import Mapping
|
|
||||||
import copy
|
import copy
|
||||||
from typing import List, Optional, Any, Union, Dict, TypeVar
|
|
||||||
from typing import Callable, Tuple, cast, Sequence, Type, Iterable
|
|
||||||
import ctypes
|
import ctypes
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import json
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Mapping
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from inspect import signature, Parameter
|
from inspect import Parameter, signature
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
|
|
||||||
from .compat import DataFrame, py_str, PANDAS_INSTALLED
|
|
||||||
from .libpath import find_lib_path
|
|
||||||
from ._typing import (
|
from ._typing import (
|
||||||
CStrPptr,
|
_T,
|
||||||
c_bst_ulong,
|
ArrayLike,
|
||||||
|
BoosterParam,
|
||||||
|
CFloatPtr,
|
||||||
CNumeric,
|
CNumeric,
|
||||||
DataType,
|
|
||||||
CNumericPtr,
|
CNumericPtr,
|
||||||
|
CStrPptr,
|
||||||
CStrPtr,
|
CStrPtr,
|
||||||
CTypeT,
|
CTypeT,
|
||||||
ArrayLike,
|
|
||||||
CFloatPtr,
|
|
||||||
NumpyOrCupy,
|
|
||||||
FeatureInfo,
|
|
||||||
FeatureTypes,
|
|
||||||
FeatureNames,
|
|
||||||
_T,
|
|
||||||
CupyT,
|
CupyT,
|
||||||
BoosterParam
|
DataType,
|
||||||
|
FeatureInfo,
|
||||||
|
FeatureNames,
|
||||||
|
FeatureTypes,
|
||||||
|
NumpyOrCupy,
|
||||||
|
c_bst_ulong,
|
||||||
)
|
)
|
||||||
|
from .compat import PANDAS_INSTALLED, DataFrame, py_str
|
||||||
|
from .libpath import find_lib_path
|
||||||
|
|
||||||
|
|
||||||
class XGBoostError(ValueError):
|
class XGBoostError(ValueError):
|
||||||
"""Error thrown by xgboost trainer."""
|
"""Error thrown by xgboost trainer."""
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def from_pystr_to_cstr(data: str) -> bytes:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def from_pystr_to_cstr(data: List[str]) -> ctypes.Array:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, ctypes.Array]:
|
def from_pystr_to_cstr(data: Union[str, List[str]]) -> Union[bytes, ctypes.Array]:
|
||||||
"""Convert a Python str or list of Python str to C pointer
|
"""Convert a Python str or list of Python str to C pointer
|
||||||
|
|
||||||
|
|||||||
@ -3,24 +3,33 @@
|
|||||||
'''Data dispatching for DMatrix.'''
|
'''Data dispatching for DMatrix.'''
|
||||||
import ctypes
|
import ctypes
|
||||||
import json
|
import json
|
||||||
import warnings
|
|
||||||
import os
|
import os
|
||||||
from typing import Any, Tuple, Callable, Optional, List, Union, Iterator, Sequence, cast
|
import warnings
|
||||||
|
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .core import c_array, _LIB, _check_call, c_str
|
|
||||||
from .core import _cuda_array_interface
|
|
||||||
from .core import DataIter, _ProxyDMatrix, DMatrix
|
|
||||||
from .compat import lazy_isinstance, DataFrame
|
|
||||||
from ._typing import (
|
from ._typing import (
|
||||||
c_bst_ulong,
|
|
||||||
DataType,
|
|
||||||
FeatureTypes,
|
|
||||||
FeatureNames,
|
|
||||||
NumpyDType,
|
|
||||||
CupyT,
|
CupyT,
|
||||||
FloatCompatible, PandasDType
|
DataType,
|
||||||
|
FeatureNames,
|
||||||
|
FeatureTypes,
|
||||||
|
FloatCompatible,
|
||||||
|
NumpyDType,
|
||||||
|
PandasDType,
|
||||||
|
c_bst_ulong,
|
||||||
|
)
|
||||||
|
from .compat import DataFrame, lazy_isinstance
|
||||||
|
from .core import (
|
||||||
|
_LIB,
|
||||||
|
DataIter,
|
||||||
|
DMatrix,
|
||||||
|
_check_call,
|
||||||
|
_cuda_array_interface,
|
||||||
|
_ProxyDMatrix,
|
||||||
|
c_array,
|
||||||
|
c_str,
|
||||||
|
from_pystr_to_cstr,
|
||||||
)
|
)
|
||||||
|
|
||||||
DispatchedDataBackendReturnType = Tuple[
|
DispatchedDataBackendReturnType = Tuple[
|
||||||
@ -631,10 +640,10 @@ def _is_cudf_df(data: DataType) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes:
|
def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes:
|
||||||
"""Extract CuDF __cuda_array_interface__. This is special as it returns a new list of
|
"""Extract CuDF __cuda_array_interface__. This is special as it returns a new list
|
||||||
data and a list of array interfaces. The data is list of categorical codes that
|
of data and a list of array interfaces. The data is list of categorical codes that
|
||||||
caller can safely ignore, but have to keep their reference alive until usage of array
|
caller can safely ignore, but have to keep their reference alive until usage of
|
||||||
interface is finished.
|
array interface is finished.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
@ -643,14 +652,18 @@ def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes:
|
|||||||
from cudf.utils.dtypes import is_categorical_dtype
|
from cudf.utils.dtypes import is_categorical_dtype
|
||||||
|
|
||||||
interfaces = []
|
interfaces = []
|
||||||
|
|
||||||
|
def append(interface: dict) -> None:
|
||||||
|
if "mask" in interface:
|
||||||
|
interface["mask"] = interface["mask"].__cuda_array_interface__
|
||||||
|
interfaces.append(interface)
|
||||||
|
|
||||||
if _is_cudf_ser(data):
|
if _is_cudf_ser(data):
|
||||||
if is_categorical_dtype(data.dtype):
|
if is_categorical_dtype(data.dtype):
|
||||||
interface = cat_codes[0].__cuda_array_interface__
|
interface = cat_codes[0].__cuda_array_interface__
|
||||||
else:
|
else:
|
||||||
interface = data.__cuda_array_interface__
|
interface = data.__cuda_array_interface__
|
||||||
if "mask" in interface:
|
append(interface)
|
||||||
interface["mask"] = interface["mask"].__cuda_array_interface__
|
|
||||||
interfaces.append(interface)
|
|
||||||
else:
|
else:
|
||||||
for i, col in enumerate(data):
|
for i, col in enumerate(data):
|
||||||
if is_categorical_dtype(data[col].dtype):
|
if is_categorical_dtype(data[col].dtype):
|
||||||
@ -658,10 +671,8 @@ def _cudf_array_interfaces(data: DataType, cat_codes: list) -> bytes:
|
|||||||
interface = codes.__cuda_array_interface__
|
interface = codes.__cuda_array_interface__
|
||||||
else:
|
else:
|
||||||
interface = data[col].__cuda_array_interface__
|
interface = data[col].__cuda_array_interface__
|
||||||
if "mask" in interface:
|
append(interface)
|
||||||
interface["mask"] = interface["mask"].__cuda_array_interface__
|
interfaces_str = from_pystr_to_cstr(json.dumps(interfaces))
|
||||||
interfaces.append(interface)
|
|
||||||
interfaces_str = bytes(json.dumps(interfaces, indent=2), "utf-8")
|
|
||||||
return interfaces_str
|
return interfaces_str
|
||||||
|
|
||||||
|
|
||||||
@ -722,9 +733,14 @@ def _transform_cudf_df(
|
|||||||
cat_codes.append(codes)
|
cat_codes.append(codes)
|
||||||
else:
|
else:
|
||||||
for col in data:
|
for col in data:
|
||||||
if is_categorical_dtype(data[col].dtype) and enable_categorical:
|
dtype = data[col].dtype
|
||||||
|
if is_categorical_dtype(dtype) and enable_categorical:
|
||||||
codes = data[col].cat.codes
|
codes = data[col].cat.codes
|
||||||
cat_codes.append(codes)
|
cat_codes.append(codes)
|
||||||
|
elif is_categorical_dtype(dtype):
|
||||||
|
raise ValueError(_ENABLE_CAT_ERR)
|
||||||
|
else:
|
||||||
|
cat_codes.append([])
|
||||||
|
|
||||||
return data, cat_codes, feature_names, feature_types
|
return data, cat_codes, feature_names, feature_types
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import sys
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
sys.path.append("tests/python")
|
sys.path.append("tests/python")
|
||||||
@ -176,20 +178,38 @@ Arrow specification.'''
|
|||||||
_test_cudf_metainfo(xgb.DeviceQuantileDMatrix)
|
_test_cudf_metainfo(xgb.DeviceQuantileDMatrix)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_cudf_categorical(self):
|
def test_cudf_categorical(self) -> None:
|
||||||
import cudf
|
import cudf
|
||||||
_X, _y = tm.make_categorical(100, 30, 17, False)
|
n_features = 30
|
||||||
|
_X, _y = tm.make_categorical(100, n_features, 17, False)
|
||||||
X = cudf.from_pandas(_X)
|
X = cudf.from_pandas(_X)
|
||||||
y = cudf.from_pandas(_y)
|
y = cudf.from_pandas(_y)
|
||||||
|
|
||||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||||
|
assert Xy.feature_types is not None
|
||||||
assert len(Xy.feature_types) == X.shape[1]
|
assert len(Xy.feature_types) == X.shape[1]
|
||||||
assert all(t == "c" for t in Xy.feature_types)
|
assert all(t == "c" for t in Xy.feature_types)
|
||||||
|
|
||||||
Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
|
Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
|
||||||
|
assert Xy.feature_types is not None
|
||||||
assert len(Xy.feature_types) == X.shape[1]
|
assert len(Xy.feature_types) == X.shape[1]
|
||||||
assert all(t == "c" for t in Xy.feature_types)
|
assert all(t == "c" for t in Xy.feature_types)
|
||||||
|
|
||||||
|
# mixed dtypes
|
||||||
|
X["1"] = X["1"].astype(np.int64)
|
||||||
|
X["3"] = X["3"].astype(np.int64)
|
||||||
|
df, cat_codes, _, _ = xgb.data._transform_cudf_df(
|
||||||
|
X, None, None, enable_categorical=True
|
||||||
|
)
|
||||||
|
assert X.shape[1] == n_features
|
||||||
|
assert len(cat_codes) == X.shape[1]
|
||||||
|
assert not cat_codes[0]
|
||||||
|
assert not cat_codes[2]
|
||||||
|
|
||||||
|
interfaces_str = xgb.data._cudf_array_interfaces(df, cat_codes)
|
||||||
|
interfaces = json.loads(interfaces_str)
|
||||||
|
assert len(interfaces) == X.shape[1]
|
||||||
|
|
||||||
# test missing value
|
# test missing value
|
||||||
X = cudf.DataFrame({"f0": ["a", "b", np.NaN]})
|
X = cudf.DataFrame({"f0": ["a", "b", np.NaN]})
|
||||||
X["f0"] = X["f0"].astype("category")
|
X["f0"] = X["f0"].astype("category")
|
||||||
@ -206,7 +226,7 @@ Arrow specification.'''
|
|||||||
assert Xy.num_row() == 3
|
assert Xy.num_row() == 3
|
||||||
assert Xy.num_col() == 1
|
assert Xy.num_col() == 1
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="enable_categorical"):
|
||||||
xgb.DeviceQuantileDMatrix(X, y)
|
xgb.DeviceQuantileDMatrix(X, y)
|
||||||
|
|
||||||
Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
|
Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user