Implement Python data handler. (#5689)

* Define data handlers for DMatrix.
* Throw ValueError in the scikit-learn interface.
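
With this change `DMatrix.__init__` no longer walks a hard-coded `isinstance` chain; it asks a registry in the new `xgboost.data` module for a handler and falls back to `_convert_unknown_data` when the lookup fails. A minimal sketch of the new flow, using only names introduced by this commit (the random input is illustrative):

    import numpy as np
    from xgboost.data import get_dmatrix_data_handler

    X = np.random.randn(5, 2).astype(np.float32)
    # 'numpy.ndarray' resolves to NumpyHandler via the registry.
    handler = get_dmatrix_data_handler(
        X, missing=np.nan, nthread=1, silent=False)
    handle, feature_names, feature_types = handler.handle_input(X, None, None)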
Jiaming Yuan 2020-05-22 11:53:55 +08:00 committed by GitHub
parent 646def51e0
commit 5af8161a1a
7 changed files with 746 additions and 405 deletions


@@ -14,7 +14,7 @@ from sklearn.datasets import load_iris, load_digits, load_boston
 rng = np.random.RandomState(31337)
 print("Zeros and Ones from the Digits dataset: binary classification")
-digits = load_digits(2)
+digits = load_digits(n_class=2)
 y = digits['target']
 X = digits['data']
 kf = KFold(n_splits=2, shuffle=True, random_state=rng)


@@ -107,7 +107,6 @@ except ImportError:
 try:
     from cudf import DataFrame as CUDF_DataFrame
     from cudf import Series as CUDF_Series
-    from cudf import MultiIndex as CUDF_MultiIndex
     from cudf import concat as CUDF_concat
     CUDF_INSTALLED = True
 except ImportError:


@@ -1,7 +1,6 @@
 # coding: utf-8
 # pylint: disable=too-many-arguments, too-many-branches, invalid-name
-# pylint: disable=too-many-branches, too-many-lines, too-many-locals
-# pylint: disable=too-many-public-methods
+# pylint: disable=too-many-lines, too-many-locals
 """Core XGBoost Library."""
 import collections
 # pylint: disable=no-name-in-module,import-error
@@ -11,16 +10,15 @@ import ctypes
 import os
 import re
 import sys
-import warnings
 import json
 import numpy as np
 import scipy.sparse
 from .compat import (
-    STRING_TYPES, DataFrame, MultiIndex, Int64Index, py_str,
+    STRING_TYPES, DataFrame, py_str,
     PANDAS_INSTALLED, CUDF_INSTALLED,
-    CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
+    CUDF_DataFrame,
     os_fspath, os_PathLike, lazy_isinstance)
 from .libpath import find_lib_path
@@ -262,10 +260,24 @@ def c_array(ctype, values):
     return (ctype * len(values))(*values)
 
-PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64':
-                       'int', 'uint8': 'int', 'uint16': 'int', 'uint32': 'int',
-                       'uint64': 'int', 'float16': 'float', 'float32': 'float',
-                       'float64': 'float', 'bool': 'i'}
+def _convert_unknown_data(data, meta=None, meta_type=None):
+    if meta is not None:
+        try:
+            data = np.array(data, dtype=meta_type)
+        except Exception:
+            raise TypeError('Can not handle data from {}'.format(
+                type(data).__name__))
+    else:
+        import warnings
+        warnings.warn(
+            'Unknown data type: ' + str(type(data)) +
+            ', converting it to csr_matrix')
+        try:
+            data = scipy.sparse.csr_matrix(data)
+        except Exception:
+            raise TypeError('Can not initialize DMatrix from'
+                            ' {}'.format(type(data).__name__))
+    return data
 
 # Either object has cuda array interface or contains columns with interfaces
@@ -274,57 +286,12 @@ def _has_cuda_array_interface(data):
         CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
 
-def _maybe_pandas_data(data, feature_names, feature_types,
-                       meta=None, meta_type=None):
-    """Extract internal data from pd.DataFrame for DMatrix data"""
-    if not (PANDAS_INSTALLED and isinstance(data, DataFrame)):
-        return data, feature_names, feature_types
-    from pandas.api.types import is_sparse
-    data_dtypes = data.dtypes
-    if not all(dtype.name in PANDAS_DTYPE_MAPPER or is_sparse(dtype)
-               for dtype in data_dtypes):
-        bad_fields = [
-            str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
-            if dtype.name not in PANDAS_DTYPE_MAPPER
-        ]
-        msg = """DataFrame.dtypes for data must be int, float or bool.
-Did not expect the data types in fields """
-        raise ValueError(msg + ', '.join(bad_fields))
-    if feature_names is None and meta is None:
-        if isinstance(data.columns, MultiIndex):
-            feature_names = [
-                ' '.join([str(x) for x in i]) for i in data.columns
-            ]
-        elif isinstance(data.columns, Int64Index):
-            feature_names = list(map(str, data.columns))
-        else:
-            feature_names = data.columns.format()
-    if feature_types is None and meta is None:
-        feature_types = []
-        for dtype in data_dtypes:
-            if is_sparse(dtype):
-                feature_types.append(PANDAS_DTYPE_MAPPER[dtype.subtype.name])
-            else:
-                feature_types.append(PANDAS_DTYPE_MAPPER[dtype.name])
-    if meta and len(data.columns) > 1:
-        raise ValueError(
-            'DataFrame for {meta} cannot have multiple columns'.format(
-                meta=meta))
-    dtype = meta_type if meta_type else 'float'
-    data = data.values.astype(dtype)
-    return data, feature_names, feature_types
-
 def _cudf_array_interfaces(df):
     '''Extract CuDF __cuda_array_interface__'''
     interfaces = []
-    if lazy_isinstance(df, 'cudf.core.series', 'Series'):
-        interfaces.append(df.__cuda_array_interface__)
-    else:
-        for col in df:
-            interface = df[col].__cuda_array_interface__
-            if 'mask' in interface:
+    for col in df:
+        interface = df[col].__cuda_array_interface__
+        if 'mask' in interface:
@@ -334,124 +301,7 @@ def _cudf_array_interfaces(df):
     return interfaces_str
 
-def _maybe_cudf_dataframe(data, feature_names, feature_types):
-    """Extract internal data from cudf.DataFrame for DMatrix data."""
-    if not (CUDF_INSTALLED and isinstance(data,
-                                          (CUDF_DataFrame, CUDF_Series))):
-        return data, feature_names, feature_types
-    if feature_names is None:
-        if isinstance(data, CUDF_Series):
-            feature_names = [data.name]
-        elif isinstance(data.columns, CUDF_MultiIndex):
-            feature_names = [
-                ' '.join([str(x) for x in i])
-                for i in data.columns
-            ]
-        else:
-            feature_names = data.columns.format()
-    if feature_types is None:
-        if isinstance(data, CUDF_Series):
-            dtypes = [data.dtype]
-        else:
-            dtypes = data.dtypes
-        feature_types = [PANDAS_DTYPE_MAPPER[d.name] for d in dtypes]
-    return data, feature_names, feature_types
-
-DT_TYPE_MAPPER = {'bool': 'bool', 'int': 'int', 'real': 'float'}
-DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
-
-def _maybe_dt_data(data, feature_names, feature_types,
-                   meta=None, meta_type=None):
-    """Validate feature names and types if data table"""
-    if (not lazy_isinstance(data, 'datatable', 'Frame') and
-            not lazy_isinstance(data, 'datatable', 'DataTable')):
-        return data, feature_names, feature_types
-    if meta and data.shape[1] > 1:
-        raise ValueError(
-            'DataTable for label or weight cannot have multiple columns')
-    if meta:
-        # below requires new dt version
-        # extract first column
-        data = data.to_numpy()[:, 0].astype(meta_type)
-        return data, None, None
-    data_types_names = tuple(lt.name for lt in data.ltypes)
-    bad_fields = [data.names[i]
-                  for i, type_name in enumerate(data_types_names)
-                  if type_name not in DT_TYPE_MAPPER]
-    if bad_fields:
-        msg = """DataFrame.types for data must be int, float or bool.
-Did not expect the data types in fields """
-        raise ValueError(msg + ', '.join(bad_fields))
-    if feature_names is None and meta is None:
-        feature_names = data.names
-        # always return stypes for dt ingestion
-        if feature_types is not None:
-            raise ValueError(
-                'DataTable has own feature types, cannot pass them in.')
-        feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)
-    return data, feature_names, feature_types
-
-def _is_dlpack(x):
-    return 'PyCapsule' in str(type(x)) and "dltensor" in str(x)
-
-# Just convert dlpack into cupy (zero copy)
-def _maybe_dlpack_data(data, feature_names, feature_types):
-    if not _is_dlpack(data):
-        return data, feature_names, feature_types
-    from cupy import fromDlpack  # pylint: disable=E0401
-    data = fromDlpack(data)
-    return data, feature_names, feature_types
-
-def _convert_dataframes(data, feature_names, feature_types,
-                        meta=None, meta_type=None):
-    data, feature_names, feature_types = _maybe_pandas_data(data,
-                                                            feature_names,
-                                                            feature_types,
-                                                            meta,
-                                                            meta_type)
-    data, feature_names, feature_types = _maybe_dt_data(data,
-                                                        feature_names,
-                                                        feature_types,
-                                                        meta,
-                                                        meta_type)
-    data, feature_names, feature_types = _maybe_cudf_dataframe(
-        data, feature_names, feature_types)
-    data, feature_names, feature_types = _maybe_dlpack_data(
-        data, feature_names, feature_types)
-    return data, feature_names, feature_types
-
-def _maybe_np_slice(data, dtype=np.float32):
-    '''Handle numpy slice. This can be removed if we use __array_interface__.
-    '''
-    try:
-        if not data.flags.c_contiguous:
-            warnings.warn(
-                "Use subset (sliced data) of np.ndarray is not recommended " +
-                "because it will generate extra copies and increase " +
-                "memory consumption")
-            data = np.array(data, copy=True, dtype=dtype)
-        else:
-            data = np.array(data, copy=False, dtype=dtype)
-    except AttributeError:
-        data = np.array(data, copy=False, dtype=dtype)
-    return data
-
-class DMatrix(object):
+class DMatrix:  # pylint: disable=too-many-instance-attributes
     """Data Matrix used in XGBoost.
 
     DMatrix is a internal data structure that used by XGBoost
@@ -503,6 +353,13 @@ class DMatrix(object):
             applicable. If -1, uses maximum threads available on the system.
 
         """
+        if isinstance(data, list):
+            raise TypeError('Input data can not be a list.')
+
+        self.missing = missing if missing is not None else np.nan
+        self.nthread = nthread if nthread is not None else 1
+        self.silent = silent
+
         # force into void_p, mac need to pass things in as void_p
         if data is None:
             self.handle = None
@@ -513,40 +370,13 @@ class DMatrix(object):
             self._feature_types = feature_types
             return
 
-        if isinstance(data, list):
-            raise TypeError('Input data can not be a list.')
-
-        data, feature_names, feature_types = _convert_dataframes(
-            data, feature_names, feature_types
-        )
-        missing = missing if missing is not None else np.nan
-        nthread = nthread if nthread is not None else 1
-
-        if isinstance(data, (STRING_TYPES, os_PathLike)):
-            handle = ctypes.c_void_p()
-            _check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
-                                                     ctypes.c_int(silent),
-                                                     ctypes.byref(handle)))
-            self.handle = handle
-        elif isinstance(data, scipy.sparse.csr_matrix):
-            self._init_from_csr(data)
-        elif isinstance(data, scipy.sparse.csc_matrix):
-            self._init_from_csc(data)
-        elif isinstance(data, np.ndarray):
-            self._init_from_npy2d(data, missing, nthread)
-        elif lazy_isinstance(data, 'datatable', 'Frame'):
-            self._init_from_dt(data, nthread)
-        elif hasattr(data, "__cuda_array_interface__"):
-            self._init_from_array_interface(data, missing, nthread)
-        elif CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
-            self._init_from_array_interface_columns(data, missing, nthread)
-        else:
-            try:
-                csr = scipy.sparse.csr_matrix(data)
-                self._init_from_csr(csr)
-            except Exception:
-                raise TypeError('can not initialize DMatrix from'
-                                ' {}'.format(type(data).__name__))
+        handler = self.get_data_handler(data)
+        if handler is None:
+            data = _convert_unknown_data(data, None)
+            handler = self.get_data_handler(data)
+        self.handle, feature_names, feature_types = handler.handle_input(
+            data, feature_names, feature_types)
+        assert self.handle, 'Failed to construct a DMatrix.'
 
         if label is not None:
             self.set_label(label)
@@ -558,126 +388,12 @@ class DMatrix(object):
         self.feature_names = feature_names
         self.feature_types = feature_types
 
-    def _init_from_csr(self, csr):
-        """Initialize data from a CSR matrix."""
-        if len(csr.indices) != len(csr.data):
-            raise ValueError('length mismatch: {} vs {}'.format(
-                len(csr.indices), len(csr.data)))
-        handle = ctypes.c_void_p()
-        _check_call(_LIB.XGDMatrixCreateFromCSREx(
-            c_array(ctypes.c_size_t, csr.indptr),
-            c_array(ctypes.c_uint, csr.indices),
-            c_array(ctypes.c_float, csr.data),
-            ctypes.c_size_t(len(csr.indptr)),
-            ctypes.c_size_t(len(csr.data)),
-            ctypes.c_size_t(csr.shape[1]),
-            ctypes.byref(handle)))
-        self.handle = handle
-
-    def _init_from_csc(self, csc):
-        """Initialize data from a CSC matrix."""
-        if len(csc.indices) != len(csc.data):
-            raise ValueError('length mismatch: {} vs {}'.format(
-                len(csc.indices), len(csc.data)))
-        handle = ctypes.c_void_p()
-        _check_call(_LIB.XGDMatrixCreateFromCSCEx(
-            c_array(ctypes.c_size_t, csc.indptr),
-            c_array(ctypes.c_uint, csc.indices),
-            c_array(ctypes.c_float, csc.data),
-            ctypes.c_size_t(len(csc.indptr)),
-            ctypes.c_size_t(len(csc.data)),
-            ctypes.c_size_t(csc.shape[0]),
-            ctypes.byref(handle)))
-        self.handle = handle
-
-    def _init_from_npy2d(self, mat, missing, nthread):
-        """Initialize data from a 2-D numpy matrix.
-
-        If ``mat`` does not have ``order='C'`` (aka row-major) or is
-        not contiguous, a temporary copy will be made.
-
-        If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will
-        be made.
-
-        So there could be as many as two temporary data copies; be mindful of
-        input layout and type if memory use is a concern.
-
-        """
-        if len(mat.shape) != 2:
-            raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
-                             mat.shape)
-        # flatten the array by rows and ensure it is float32. we try to avoid
-        # data copies if possible (reshape returns a view when possible and we
-        # explicitly tell np.array to try and avoid copying)
-        data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
-        handle = ctypes.c_void_p()
-        _check_call(_LIB.XGDMatrixCreateFromMat_omp(
-            data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
-            c_bst_ulong(mat.shape[0]),
-            c_bst_ulong(mat.shape[1]),
-            ctypes.c_float(missing),
-            ctypes.byref(handle),
-            ctypes.c_int(nthread)))
-        self.handle = handle
-
-    def _init_from_dt(self, data, nthread):
-        """Initialize data from a datatable Frame."""
-        ptrs = (ctypes.c_void_p * data.ncols)()
-        if hasattr(data, "internal") and hasattr(data.internal, "column"):
-            # datatable>0.8.0
-            for icol in range(data.ncols):
-                col = data.internal.column(icol)
-                ptr = col.data_pointer
-                ptrs[icol] = ctypes.c_void_p(ptr)
-        else:
-            # datatable<=0.8.0
-            from datatable.internal import \
-                frame_column_data_r  # pylint: disable=no-name-in-module,import-error
-            for icol in range(data.ncols):
-                ptrs[icol] = frame_column_data_r(data, icol)
-
-        # always return stypes for dt ingestion
-        feature_type_strings = (ctypes.c_char_p * data.ncols)()
-        for icol in range(data.ncols):
-            feature_type_strings[icol] = ctypes.c_char_p(
-                data.stypes[icol].name.encode('utf-8'))
-
-        handle = ctypes.c_void_p()
-        _check_call(_LIB.XGDMatrixCreateFromDT(
-            ptrs, feature_type_strings,
-            c_bst_ulong(data.shape[0]),
-            c_bst_ulong(data.shape[1]),
-            ctypes.byref(handle),
-            ctypes.c_int(nthread)))
-        self.handle = handle
-
-    def _init_from_array_interface_columns(self, df, missing, nthread):
-        """Initialize DMatrix from columnar memory format."""
-        interfaces_str = _cudf_array_interfaces(df)
-        handle = ctypes.c_void_p()
-        _check_call(
-            _LIB.XGDMatrixCreateFromArrayInterfaceColumns(
-                interfaces_str,
-                ctypes.c_float(missing),
-                ctypes.c_int(nthread),
-                ctypes.byref(handle)))
-        self.handle = handle
-
-    def _init_from_array_interface(self, data, missing, nthread):
-        """Initialize DMatrix from cupy ndarray."""
-        interface = data.__cuda_array_interface__
-        if 'mask' in interface:
-            interface['mask'] = interface['mask'].__cuda_array_interface__
-        interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
-        handle = ctypes.c_void_p()
-        _check_call(
-            _LIB.XGDMatrixCreateFromArrayInterface(
-                interface_str,
-                ctypes.c_float(missing),
-                ctypes.c_int(nthread),
-                ctypes.byref(handle)))
-        self.handle = handle
+    def get_data_handler(self, data, meta=None, meta_type=None):
+        '''Get data handler for this DMatrix class.'''
+        from .data import get_dmatrix_data_handler
+        handler = get_dmatrix_data_handler(
+            data, self.missing, self.nthread, self.silent, meta, meta_type)
+        return handler
 
     def __del__(self):
         if hasattr(self, "handle") and self.handle:
@@ -737,10 +453,14 @@ class DMatrix(object):
         data: numpy array
             The array of data to be set
         """
-        data, _, _ = _convert_dataframes(data, None, None, field, 'float')
         if isinstance(data, np.ndarray):
             self.set_float_info_npy2d(field, data)
             return
+        handler = self.get_data_handler(data, field, np.float32)
+        if handler is None:
+            data = _convert_unknown_data(data, field, np.float32)
+            handler = self.get_data_handler(data, field, np.float32)
+        data, _, _ = handler.transform(data)
         c_data = c_array(ctypes.c_float, data)
         _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
                                                c_str(field),
@@ -759,7 +479,8 @@ class DMatrix(object):
         data: numpy array
             The array of data to be set
         """
-        data = _maybe_np_slice(data, np.float32)
+        data, _, _ = self.get_data_handler(
+            data, field, np.float32).transform(data)
         c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
         _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
                                                c_str(field),
@@ -777,9 +498,8 @@ class DMatrix(object):
         data: numpy array
             The array of data to be set
         """
-        data = _maybe_np_slice(data, np.uint32)
-        data, _, _ = _convert_dataframes(data, None, None, field, 'uint32')
-        data = np.array(data, copy=False, dtype=ctypes.c_uint)
+        data, _, _ = self.get_data_handler(
+            data, field, 'uint32').transform(data)
         _check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
                                               c_str(field),
                                               c_array(ctypes.c_uint, data),
@@ -1075,46 +795,18 @@ class DeviceQuantileDMatrix(DMatrix):
                  feature_types=None,
                  nthread=None, max_bin=256):
         self.max_bin = max_bin
-        if not (hasattr(data, "__cuda_array_interface__") or (
-                CUDF_INSTALLED and isinstance(data, CUDF_DataFrame)) or _is_dlpack(data)):
-            raise ValueError('Only cupy/cudf/dlpack currently supported for DeviceQuantileDMatrix')
-        super().__init__(data, label=label, weight=weight, base_margin=base_margin,
+        super().__init__(data, label=label, weight=weight,
+                         base_margin=base_margin,
                          missing=missing,
                          silent=silent,
                          feature_names=feature_names,
                          feature_types=feature_types,
                          nthread=nthread)
 
-    def _init_from_array_interface_columns(self, df, missing, nthread):
-        """Initialize DMatrix from columnar memory format."""
-        interfaces_str = _cudf_array_interfaces(df)
-        handle = ctypes.c_void_p()
-        missing = missing if missing is not None else np.nan
-        nthread = nthread if nthread is not None else 1
-        _check_call(
-            _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
-                interfaces_str,
-                ctypes.c_float(missing), ctypes.c_int(nthread),
-                ctypes.c_int(self.max_bin), ctypes.byref(handle)))
-        self.handle = handle
-
-    def _init_from_array_interface(self, data, missing, nthread):
-        """Initialize DMatrix from cupy ndarray."""
-        interface = data.__cuda_array_interface__
-        if 'mask' in interface:
-            interface['mask'] = interface['mask'].__cuda_array_interface__
-        interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
-        handle = ctypes.c_void_p()
-        missing = missing if missing is not None else np.nan
-        nthread = nthread if nthread is not None else 1
-        _check_call(
-            _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
-                interface_str,
-                ctypes.c_float(missing), ctypes.c_int(nthread),
-                ctypes.c_int(self.max_bin), ctypes.byref(handle)))
-        self.handle = handle
+    def get_data_handler(self, data, meta=None, meta_type=None):
+        from .data import get_device_quantile_dmatrix_data_handler
+        return get_device_quantile_dmatrix_data_handler(
+            data, self.max_bin, self.missing, self.nthread, self.silent)
 
 
 class Booster(object):
@@ -1467,6 +1159,7 @@ class Booster(object):
         self._validate_features(data)
         return self.eval_set([(data, name)], iteration)
 
+    # pylint: disable=too-many-function-args
     def predict(self,
                 data,
                 output_margin=False,
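
Taken together, the core.py changes route every input through the handler registry: a type with no registered handler is first coerced by `_convert_unknown_data` (to `csr_matrix` for feature data, to `np.array` for meta fields) and then dispatched again, while lists are rejected up front. A small illustration of the new error surface, with throwaway data:

    import xgboost as xgb

    # Lists are rejected before any handler lookup happens.
    try:
        xgb.DMatrix([[1, 2], [3, 4]])
    except TypeError as e:
        print(e)  # 'Input data can not be a list.'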


@@ -0,0 +1,624 @@
# pylint: disable=too-many-arguments, no-self-use
'''Data dispatching for DMatrix.'''
import ctypes
import abc
import json
import warnings

import numpy as np

from .core import c_array, _LIB, _check_call, c_str, _cudf_array_interfaces
from .compat import lazy_isinstance, STRING_TYPES, os_fspath, os_PathLike

c_bst_ulong = ctypes.c_uint64   # pylint: disable=invalid-name
class DataHandler(abc.ABC):
    '''Base class for various data handler.'''
    def __init__(self, missing, nthread, silent, meta=None, meta_type=None):
        self.missing = missing
        self.nthread = nthread
        self.silent = silent

        self.meta = meta
        self.meta_type = meta_type

    def _warn_unused_missing(self, data):
        if not (np.isnan(self.missing) or self.missing is None):
            warnings.warn(
                '`missing` is not used for current input data type:' +
                str(type(data)))

    def check_complex(self, data):
        '''Test whether data is complex using `dtype` attribute.'''
        complex_dtypes = (np.complex128, np.complex64,
                          np.cfloat, np.cdouble, np.clongdouble)
        if hasattr(data, 'dtype') and data.dtype in complex_dtypes:
            raise ValueError('Complex data not supported')

    def transform(self, data):
        '''Optional method for transforming data before being accepted by
        other XGBoost API.'''
        return data, None, None

    @abc.abstractmethod
    def handle_input(self, data, feature_names, feature_types):
        '''Abstract method for handling different data input.'''
class DMatrixDataManager:
    '''The registry class for various data handler.'''
    def __init__(self):
        self.__data_handlers = {}
        self.__data_handlers_dly = []

    def register_handler(self, module, name, handler):
        '''Register a data handler handling specific type of data.'''
        self.__data_handlers['.'.join([module, name])] = handler

    def register_handler_opaque(self, func, handler):
        '''Register a data handler that handles data with opaque type.

        Parameters
        ----------
        func : callable
            A function with a single parameter `data`.  It should return True
            if the handler can handle this data, otherwise returns False.
        handler : xgboost.data.DataHandler
            The handler class that is a subclass of `DataHandler`.
        '''
        self.__data_handlers_dly.append((func, handler))

    def get_handler(self, data):
        '''Get a handler of `data`, returns None if handler not found.'''
        module, name = type(data).__module__, type(data).__name__
        if '.'.join([module, name]) in self.__data_handlers.keys():
            handler = self.__data_handlers['.'.join([module, name])]
            return handler
        for f, handler in self.__data_handlers_dly:
            if f(data):
                return handler
        return None


__dmatrix_registry = DMatrixDataManager()  # pylint: disable=invalid-name


def get_dmatrix_data_handler(data, missing, nthread, silent,
                             meta=None, meta_type=None):
    '''Get a handler of `data` for DMatrix.

    .. versionadded:: 1.2.0

    Parameters
    ----------
    data : any
        The input data.
    missing : float
        Same as `missing` for DMatrix.
    nthread : int
        Same as `nthread` for DMatrix.
    silent : boolean
        Same as `silent` for DMatrix.
    meta : str
        Field name of meta data, like `label`.  Used only for getting handler
        for meta info.
    meta_type : str/np.dtype
        Type of meta data.

    Returns
    -------
    handler : DataHandler
    '''
    handler = __dmatrix_registry.get_handler(data)
    if handler is None:
        return None
    return handler(missing, nthread, silent, meta, meta_type)
class FileHandler(DataHandler):
    '''Handler of path like input.'''
    def handle_input(self, data, feature_names, feature_types):
        self._warn_unused_missing(data)
        handle = ctypes.c_void_p()
        _check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
                                                 ctypes.c_int(self.silent),
                                                 ctypes.byref(handle)))
        return handle, feature_names, feature_types


__dmatrix_registry.register_handler_opaque(
    lambda data: isinstance(data, (STRING_TYPES, os_PathLike)),
    FileHandler)


class CSRHandler(DataHandler):
    '''Handler of `scipy.sparse.csr.csr_matrix`.'''
    def handle_input(self, data, feature_names, feature_types):
        '''Initialize data from a CSR matrix.'''
        if len(data.indices) != len(data.data):
            raise ValueError('length mismatch: {} vs {}'.format(
                len(data.indices), len(data.data)))
        self._warn_unused_missing(data)
        handle = ctypes.c_void_p()
        _check_call(_LIB.XGDMatrixCreateFromCSREx(
            c_array(ctypes.c_size_t, data.indptr),
            c_array(ctypes.c_uint, data.indices),
            c_array(ctypes.c_float, data.data),
            ctypes.c_size_t(len(data.indptr)),
            ctypes.c_size_t(len(data.data)),
            ctypes.c_size_t(data.shape[1]),
            ctypes.byref(handle)))
        return handle, feature_names, feature_types


__dmatrix_registry.register_handler(
    'scipy.sparse.csr', 'csr_matrix', CSRHandler)


class CSCHandler(DataHandler):
    '''Handler of `scipy.sparse.csc.csc_matrix`.'''
    def handle_input(self, data, feature_names, feature_types):
        if len(data.indices) != len(data.data):
            raise ValueError('length mismatch: {} vs {}'.format(
                len(data.indices), len(data.data)))
        self._warn_unused_missing(data)
        handle = ctypes.c_void_p()
        _check_call(_LIB.XGDMatrixCreateFromCSCEx(
            c_array(ctypes.c_size_t, data.indptr),
            c_array(ctypes.c_uint, data.indices),
            c_array(ctypes.c_float, data.data),
            ctypes.c_size_t(len(data.indptr)),
            ctypes.c_size_t(len(data.data)),
            ctypes.c_size_t(data.shape[0]),
            ctypes.byref(handle)))
        return handle, feature_names, feature_types


__dmatrix_registry.register_handler(
    'scipy.sparse.csc', 'csc_matrix', CSCHandler)


class NumpyHandler(DataHandler):
    '''Handler of `numpy.ndarray`.'''
    def _maybe_np_slice(self, data, dtype):
        '''Handle numpy slice.  This can be removed if we use
        __array_interface__.
        '''
        try:
            if not data.flags.c_contiguous:
                warnings.warn(
                    "Use subset (sliced data) of np.ndarray is not recommended " +
                    "because it will generate extra copies and increase " +
                    "memory consumption")
                data = np.array(data, copy=True, dtype=dtype)
            else:
                data = np.array(data, copy=False, dtype=dtype)
        except AttributeError:
            data = np.array(data, copy=False, dtype=dtype)
        return data

    def transform(self, data):
        return self._maybe_np_slice(data, self.meta_type), None, None

    def handle_input(self, data, feature_names, feature_types):
        """Initialize data from a 2-D numpy matrix.

        If ``mat`` does not have ``order='C'`` (aka row-major) or is
        not contiguous, a temporary copy will be made.

        If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy
        will be made.

        So there could be as many as two temporary data copies; be mindful of
        input layout and type if memory use is a concern.

        """
        if not isinstance(data, np.ndarray) and hasattr(data, '__array__'):
            data = np.array(data, copy=False)
        if len(data.shape) != 2:
            raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
                             data.shape)
        # flatten the array by rows and ensure it is float32.  we try to avoid
        # data copies if possible (reshape returns a view when possible and we
        # explicitly tell np.array to try and avoid copying)
        flatten = np.array(data.reshape(data.size), copy=False,
                           dtype=np.float32)
        flatten = self._maybe_np_slice(flatten, np.float32)
        self.check_complex(data)
        handle = ctypes.c_void_p()
        _check_call(_LIB.XGDMatrixCreateFromMat_omp(
            flatten.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            c_bst_ulong(data.shape[0]),
            c_bst_ulong(data.shape[1]),
            ctypes.c_float(self.missing),
            ctypes.byref(handle),
            ctypes.c_int(self.nthread)))
        return handle, feature_names, feature_types


__dmatrix_registry.register_handler('numpy', 'ndarray', NumpyHandler)
__dmatrix_registry.register_handler('numpy', 'matrix', NumpyHandler)
__dmatrix_registry.register_handler_opaque(
    lambda x: hasattr(x, '__array__'), NumpyHandler)


class ListHandler(NumpyHandler):
    '''Handler of builtin list and tuple.'''
    def handle_input(self, data, feature_names, feature_types):
        assert self.meta is None, 'List input data is not supported for X'
        data = np.array(data)
        return super().handle_input(data, feature_names, feature_types)


__dmatrix_registry.register_handler('builtins', 'list', ListHandler)
__dmatrix_registry.register_handler('builtins', 'tuple', ListHandler)
class PandasHandler(NumpyHandler):
    '''Handler of data structures defined by `pandas`.'''
    pandas_dtype_mapper = {
        'int8': 'int',
        'int16': 'int',
        'int32': 'int',
        'int64': 'int',
        'uint8': 'int',
        'uint16': 'int',
        'uint32': 'int',
        'uint64': 'int',
        'float16': 'float',
        'float32': 'float',
        'float64': 'float',
        'bool': 'i'
    }

    def _maybe_pandas_data(self, data, feature_names, feature_types,
                           meta=None, meta_type=None):
        """Extract internal data from pd.DataFrame for DMatrix data"""
        if lazy_isinstance(data, 'pandas.core.series', 'Series'):
            dtype = meta_type if meta_type else 'float'
            return data.values.astype(dtype), feature_names, feature_types

        from pandas.api.types import is_sparse
        from pandas import MultiIndex, Int64Index

        data_dtypes = data.dtypes
        if not all(dtype.name in self.pandas_dtype_mapper or is_sparse(dtype)
                   for dtype in data_dtypes):
            bad_fields = [
                str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
                if dtype.name not in self.pandas_dtype_mapper
            ]
            msg = """DataFrame.dtypes for data must be int, float or bool.
Did not expect the data types in fields """
            raise ValueError(msg + ', '.join(bad_fields))

        if feature_names is None and meta is None:
            if isinstance(data.columns, MultiIndex):
                feature_names = [
                    ' '.join([str(x) for x in i]) for i in data.columns
                ]
            elif isinstance(data.columns, Int64Index):
                feature_names = list(map(str, data.columns))
            else:
                feature_names = data.columns.format()

        if feature_types is None and meta is None:
            feature_types = []
            for dtype in data_dtypes:
                if is_sparse(dtype):
                    feature_types.append(self.pandas_dtype_mapper[
                        dtype.subtype.name])
                else:
                    feature_types.append(self.pandas_dtype_mapper[dtype.name])

        if meta and len(data.columns) > 1:
            raise ValueError(
                'DataFrame for {meta} cannot have multiple columns'.format(
                    meta=meta))

        dtype = meta_type if meta_type else 'float'
        data = data.values.astype(dtype)

        return data, feature_names, feature_types

    def transform(self, data):
        return self._maybe_pandas_data(data, None, None, self.meta,
                                       self.meta_type)

    def handle_input(self, data, feature_names, feature_types):
        data, feature_names, feature_types = self._maybe_pandas_data(
            data, feature_names, feature_types, self.meta, self.meta_type)
        return super().handle_input(data, feature_names, feature_types)


__dmatrix_registry.register_handler(
    'pandas.core.frame', 'DataFrame', PandasHandler)
__dmatrix_registry.register_handler(
    'pandas.core.series', 'Series', PandasHandler)


class DTHandler(DataHandler):
    '''Handler of datatable.'''
    dt_type_mapper = {'bool': 'bool', 'int': 'int', 'real': 'float'}
    dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'}

    def _maybe_dt_data(self, data, feature_names, feature_types,
                       meta=None, meta_type=None):
        """Validate feature names and types if data table"""
        if meta and data.shape[1] > 1:
            raise ValueError(
                'DataTable for label or weight cannot have multiple columns')
        if meta:
            # below requires new dt version
            # extract first column
            data = data.to_numpy()[:, 0].astype(meta_type)
            return data, None, None

        data_types_names = tuple(lt.name for lt in data.ltypes)
        bad_fields = [data.names[i]
                      for i, type_name in enumerate(data_types_names)
                      if type_name not in self.dt_type_mapper]
        if bad_fields:
            msg = """DataFrame.types for data must be int, float or bool.
Did not expect the data types in fields """
            raise ValueError(msg + ', '.join(bad_fields))

        if feature_names is None and meta is None:
            feature_names = data.names
            # always return stypes for dt ingestion
            if feature_types is not None:
                raise ValueError(
                    'DataTable has own feature types, cannot pass them in.')
            feature_types = np.vectorize(self.dt_type_mapper2.get)(
                data_types_names)
        return data, feature_names, feature_types

    def transform(self, data):
        return self._maybe_dt_data(data, None, None, self.meta,
                                   self.meta_type)

    def handle_input(self, data, feature_names, feature_types):
        data, feature_names, feature_types = self._maybe_dt_data(
            data, feature_names, feature_types, self.meta, self.meta_type)

        ptrs = (ctypes.c_void_p * data.ncols)()
        if hasattr(data, "internal") and hasattr(data.internal, "column"):
            # datatable>0.8.0
            for icol in range(data.ncols):
                col = data.internal.column(icol)
                ptr = col.data_pointer
                ptrs[icol] = ctypes.c_void_p(ptr)
        else:
            # datatable<=0.8.0
            from datatable.internal import \
                frame_column_data_r  # pylint: disable=no-name-in-module,import-error
            for icol in range(data.ncols):
                ptrs[icol] = frame_column_data_r(data, icol)

        # always return stypes for dt ingestion
        feature_type_strings = (ctypes.c_char_p * data.ncols)()
        for icol in range(data.ncols):
            feature_type_strings[icol] = ctypes.c_char_p(
                data.stypes[icol].name.encode('utf-8'))

        self._warn_unused_missing(data)
        handle = ctypes.c_void_p()
        _check_call(_LIB.XGDMatrixCreateFromDT(
            ptrs, feature_type_strings,
            c_bst_ulong(data.shape[0]),
            c_bst_ulong(data.shape[1]),
            ctypes.byref(handle),
            ctypes.c_int(self.nthread)))
        return handle, feature_names, feature_types


__dmatrix_registry.register_handler('datatable', 'Frame', DTHandler)
__dmatrix_registry.register_handler('datatable', 'DataTable', DTHandler)
class CudaArrayInterfaceHandler(DataHandler):
    '''Handler of data with `__cuda_array_interface__` (cupy.ndarray).'''
    def handle_input(self, data, feature_names, feature_types):
        """Initialize DMatrix from cupy ndarray."""
        interface = data.__cuda_array_interface__
        if 'mask' in interface:
            interface['mask'] = interface['mask'].__cuda_array_interface__
        interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')

        handle = ctypes.c_void_p()
        _check_call(
            _LIB.XGDMatrixCreateFromArrayInterface(
                interface_str,
                ctypes.c_float(self.missing),
                ctypes.c_int(self.nthread),
                ctypes.byref(handle)))
        return handle, feature_names, feature_types


__dmatrix_registry.register_handler('cupy.core.core', 'ndarray',
                                    CudaArrayInterfaceHandler)


class CudaColumnarHandler(DataHandler):
    '''Handler of CUDA based columnar data. (cudf.DataFrame)'''
    def _maybe_cudf_dataframe(self, data, feature_names, feature_types):
        """Extract internal data from cudf.DataFrame for DMatrix data."""
        if feature_names is None:
            if lazy_isinstance(data, 'cudf.core.series', 'Series'):
                feature_names = [data.name]
            elif lazy_isinstance(
                    data.columns, 'cudf.core.multiindex', 'MultiIndex'):
                feature_names = [
                    ' '.join([str(x) for x in i])
                    for i in data.columns
                ]
            else:
                feature_names = data.columns.format()
        if feature_types is None:
            if lazy_isinstance(data, 'cudf.core.series', 'Series'):
                dtypes = [data.dtype]
            else:
                dtypes = data.dtypes
            feature_types = [PandasHandler.pandas_dtype_mapper[d.name]
                             for d in dtypes]
        return data, feature_names, feature_types

    def transform(self, data):
        return self._maybe_cudf_dataframe(data, None, None)

    def handle_input(self, data, feature_names, feature_types):
        """Initialize DMatrix from columnar memory format."""
        data, feature_names, feature_types = self._maybe_cudf_dataframe(
            data, feature_names, feature_types)
        interfaces_str = _cudf_array_interfaces(data)
        handle = ctypes.c_void_p()
        _check_call(
            _LIB.XGDMatrixCreateFromArrayInterfaceColumns(
                interfaces_str,
                ctypes.c_float(self.missing),
                ctypes.c_int(self.nthread),
                ctypes.byref(handle)))
        return handle, feature_names, feature_types


__dmatrix_registry.register_handler('cudf.core.dataframe', 'DataFrame',
                                    CudaColumnarHandler)
__dmatrix_registry.register_handler('cudf.core.series', 'Series',
                                    CudaColumnarHandler)


class DLPackHandler(CudaArrayInterfaceHandler):
    '''Handler of `dlpack`.'''
    def _maybe_dlpack_data(self, data, feature_names, feature_types):
        from cupy import fromDlpack  # pylint: disable=E0401
        data = fromDlpack(data)
        return data, feature_names, feature_types

    def transform(self, data):
        return self._maybe_dlpack_data(data, None, None)

    def handle_input(self, data, feature_names, feature_types):
        data, feature_names, feature_types = self._maybe_dlpack_data(
            data, feature_names, feature_types)
        return super().handle_input(
            data, feature_names, feature_types)


__dmatrix_registry.register_handler_opaque(
    lambda x: 'PyCapsule' in str(type(x)) and "dltensor" in str(x),
    DLPackHandler)
class DeviceQuantileDMatrixDataHandler(DataHandler):  # pylint: disable=abstract-method
    '''Base class of data handler for `DeviceQuantileDMatrix`.'''
    def __init__(self, max_bin, missing, nthread, silent,
                 meta=None, meta_type=None):
        self.max_bin = max_bin
        super().__init__(missing, nthread, silent, meta, meta_type)


__device_quantile_dmatrix_registry = DMatrixDataManager()  # pylint: disable=invalid-name


def get_device_quantile_dmatrix_data_handler(
        data, max_bin, missing, nthread, silent):
    '''Get data handler for `DeviceQuantileDMatrix`.  Similar to
    `get_dmatrix_data_handler`.

    .. versionadded:: 1.2.0

    '''
    handler = __device_quantile_dmatrix_registry.get_handler(data)
    assert handler, 'Current data type ' + str(type(data)) + \
        ' is not supported for DeviceQuantileDMatrix'
    return handler(max_bin, missing, nthread, silent)


class DeviceQuantileCudaArrayInterfaceHandler(
        DeviceQuantileDMatrixDataHandler):
    '''Handler of data with `__cuda_array_interface__`, for
    `DeviceQuantileDMatrix`.

    '''
    def handle_input(self, data, feature_names, feature_types):
        """Initialize DMatrix from cupy ndarray."""
        if not hasattr(data, '__cuda_array_interface__') and hasattr(
                data, '__array__'):
            import cupy  # pylint: disable=import-error
            data = cupy.array(data, copy=False)

        interface = data.__cuda_array_interface__
        if 'mask' in interface:
            interface['mask'] = interface['mask'].__cuda_array_interface__
        interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')

        handle = ctypes.c_void_p()
        _check_call(
            _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
                interface_str,
                ctypes.c_float(self.missing), ctypes.c_int(self.nthread),
                ctypes.c_int(self.max_bin), ctypes.byref(handle)))
        return handle, feature_names, feature_types


__device_quantile_dmatrix_registry.register_handler(
    'cupy.core.core', 'ndarray', DeviceQuantileCudaArrayInterfaceHandler)
__device_quantile_dmatrix_registry.register_handler_opaque(
    lambda x: hasattr(x, '__array__'),
    DeviceQuantileCudaArrayInterfaceHandler)
__device_quantile_dmatrix_registry.register_handler_opaque(
    lambda x: hasattr(x, '__cuda_array_interface__'),
    DeviceQuantileCudaArrayInterfaceHandler)
class DeviceQuantileCudaColumnarHandler(DeviceQuantileDMatrixDataHandler,
                                        CudaColumnarHandler):
    '''Handler of CUDA based columnar data, for `DeviceQuantileDMatrix`.'''
    def __init__(self, max_bin, missing, nthread, silent,
                 meta=None, meta_type=None):
        super().__init__(
            max_bin=max_bin, missing=missing, nthread=nthread, silent=silent,
            meta=meta, meta_type=meta_type
        )

    def handle_input(self, data, feature_names, feature_types):
        """Initialize Quantile Device DMatrix from columnar memory format."""
        data, feature_names, feature_types = self._maybe_cudf_dataframe(
            data, feature_names, feature_types)
        interfaces_str = _cudf_array_interfaces(data)
        handle = ctypes.c_void_p()
        _check_call(
            _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
                interfaces_str,
                ctypes.c_float(self.missing), ctypes.c_int(self.nthread),
                ctypes.c_int(self.max_bin), ctypes.byref(handle)))
        return handle, feature_names, feature_types


__device_quantile_dmatrix_registry.register_handler(
    'cudf.core.dataframe', 'DataFrame', DeviceQuantileCudaColumnarHandler)
__device_quantile_dmatrix_registry.register_handler(
    'cudf.core.series', 'Series', DeviceQuantileCudaColumnarHandler)


class DeviceQuantileDLPackHandler(DeviceQuantileCudaArrayInterfaceHandler,
                                  DLPackHandler):
    '''Handler of `dlpack`, for `DeviceQuantileDMatrix`.'''
    def __init__(self, max_bin, missing, nthread, silent,
                 meta=None, meta_type=None):
        super().__init__(
            max_bin=max_bin, missing=missing, nthread=nthread, silent=silent,
            meta=meta, meta_type=meta_type
        )

    def handle_input(self, data, feature_names, feature_types):
        data, feature_names, feature_types = self._maybe_dlpack_data(
            data, feature_names, feature_types)
        return super().handle_input(
            data, feature_names, feature_types)


__device_quantile_dmatrix_registry.register_handler_opaque(
    lambda x: 'PyCapsule' in str(type(x)) and "dltensor" in str(x),
    DeviceQuantileDLPackHandler)
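
The registry resolves a handler in two steps: an exact `type(data).__module__ + '.' + type(data).__name__` match first, then the opaque predicates in registration order. A standalone sketch of that lookup built on a throwaway registry (the module-level `__dmatrix_registry` above is private, so this is illustrative only):

    import numpy as np
    from xgboost.data import DMatrixDataManager, NumpyHandler, CSRHandler

    registry = DMatrixDataManager()
    registry.register_handler('numpy', 'ndarray', NumpyHandler)
    # Opaque fallback: a predicate instead of an exact module/class name.
    registry.register_handler_opaque(
        lambda d: hasattr(d, 'indptr'), CSRHandler)

    # Exact match wins; predicates are only tried for unknown types.
    assert registry.get_handler(np.zeros((2, 2))) is NumpyHandler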


@@ -246,7 +246,7 @@ class XGBModel(XGBModelBase):
 
     def _more_tags(self):
         '''Tags used for scikit-learn data validation.'''
-        return {'allow_nan': True}
+        return {'allow_nan': True, 'no_validation': True}
 
     def get_booster(self):
         """Get the underlying xgboost Booster of this model.
@@ -258,7 +258,8 @@ class XGBModel(XGBModelBase):
             booster : a xgboost booster of underlying model
         """
         if not hasattr(self, '_Booster'):
-            raise XGBoostError('need to call fit or load_model beforehand')
+            from sklearn.exceptions import NotFittedError
+            raise NotFittedError('need to call fit or load_model beforehand')
         return self._Booster
 
     def set_params(self, **params):
@@ -332,7 +333,7 @@ class XGBModel(XGBModelBase):
                 for k, v in internal.items():
                     if k in params.keys() and params[k] is None:
                         params[k] = parse_parameter(v)
-            except XGBoostError:
+            except ValueError:
                 pass
         return params
@@ -536,12 +537,16 @@ class XGBModel(XGBModelBase):
         else:
             params.update({'eval_metric': eval_metric})
 
-        self._Booster = train(params, train_dmatrix,
-                              self.get_num_boosting_rounds(), evals=evals,
-                              early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, obj=obj, feval=feval,
-                              verbose_eval=verbose, xgb_model=xgb_model,
-                              callbacks=callbacks)
+        try:
+            self._Booster = train(params, train_dmatrix,
+                                  self.get_num_boosting_rounds(), evals=evals,
+                                  early_stopping_rounds=early_stopping_rounds,
+                                  evals_result=evals_result,
+                                  obj=obj, feval=feval,
+                                  verbose_eval=verbose, xgb_model=xgb_model,
+                                  callbacks=callbacks)
+        except XGBoostError as e:
+            raise ValueError(e)
 
         if evals_result:
             for val in evals_result.items():
@@ -1225,6 +1230,7 @@ class XGBRanker(XGBModel):
                 'Custom evaluation metric is not yet supported for XGBRanker.')
             params.update({'eval_metric': eval_metric})
 
+        try:
             self._Booster = train(params, train_dmatrix,
                                   self.n_estimators,
                                   early_stopping_rounds=early_stopping_rounds,
@@ -1232,6 +1238,8 @@ class XGBRanker(XGBModel):
                                   evals_result=evals_result, feval=feval,
                                   verbose_eval=verbose, xgb_model=xgb_model,
                                   callbacks=callbacks)
+        except XGBoostError as e:
+            raise ValueError(e)
 
         self.objective = params["objective"]
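
The net effect on the scikit-learn wrapper is that error types now follow scikit-learn conventions: an unfitted estimator raises `NotFittedError` instead of `XGBoostError`, and an `XGBoostError` escaping `train()` inside `fit()` is re-raised as `ValueError`, which is what scikit-learn's estimator checks expect for invalid input. A quick illustration of the first case:

    import xgboost as xgb
    from sklearn.exceptions import NotFittedError

    clf = xgb.XGBClassifier()
    try:
        clf.get_booster()  # no fit() yet
    except NotFittedError as e:
        print(e)  # 'need to call fit or load_model beforehand'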


@@ -22,6 +22,16 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
     dtrain = DMatrixT(X, missing=missing, label=y)
     assert dtrain.num_col() == kCols
     assert dtrain.num_row() == kRows
+    if DMatrixT is xgb.DeviceQuantileDMatrix:
+        # Slice is not supported by DeviceQuantileDMatrix
+        with pytest.raises(xgb.core.XGBoostError):
+            dtrain.slice(rindex=[0, 1, 2])
+            dtrain.slice(rindex=[0, 1, 2])
+    else:
+        dtrain.slice(rindex=[0, 1, 2])
+        dtrain.slice(rindex=[0, 1, 2])
     return dtrain
@@ -41,7 +51,7 @@ def _test_from_cupy(DMatrixT):
     with pytest.raises(Exception):
         X = cp.random.randn(2, 2, dtype="float32")
-        dtrain = DMatrixT(X, label=X)
+        DMatrixT(X, label=X)
 
 
 def _test_cupy_training(DMatrixT):
@@ -88,11 +98,14 @@ def _test_cupy_metainfo(DMatrixT):
     dmat_cupy.set_interface_info('group', cupy_uints)
 
     # Test setting info with cupy
-    assert np.array_equal(dmat.get_float_info('weight'), dmat_cupy.get_float_info('weight'))
-    assert np.array_equal(dmat.get_float_info('label'), dmat_cupy.get_float_info('label'))
+    assert np.array_equal(dmat.get_float_info('weight'),
+                          dmat_cupy.get_float_info('weight'))
+    assert np.array_equal(dmat.get_float_info('label'),
+                          dmat_cupy.get_float_info('label'))
     assert np.array_equal(dmat.get_float_info('base_margin'),
                           dmat_cupy.get_float_info('base_margin'))
-    assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
+    assert np.array_equal(dmat.get_uint_info('group_ptr'),
+                          dmat_cupy.get_uint_info('group_ptr'))
 
 
 class TestFromCupy:
@@ -135,7 +148,9 @@ Arrow specification.'''
         import cupy as cp
         n = 100
         X = cp.random.random((n, 2))
-        xgb.DeviceQuantileDMatrix(X.toDlpack())
+        m = xgb.DeviceQuantileDMatrix(X.toDlpack())
+        with pytest.raises(xgb.core.XGBoostError):
+            m.slice(rindex=[0, 1, 2])
 
     @pytest.mark.skipif(**tm.no_cupy())
     @pytest.mark.mgpu


@@ -67,7 +67,8 @@ class TestPandas(unittest.TestCase):
         # 0  1  1  0  0
         # 1  2  0  1  0
         # 2  3  0  0  1
-        result, _, _ = xgb.core._maybe_pandas_data(dummies, None, None)
+        pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
+        result, _, _ = pandas_handler._maybe_pandas_data(dummies, None, None)
         exp = np.array([[1., 1., 0., 0.],
                         [2., 0., 1., 0.],
                         [3., 0., 0., 1.]])
@@ -113,12 +114,12 @@ class TestPandas(unittest.TestCase):
         import pandas as pd
         rows = 100
         X = pd.DataFrame(
-            {"A": pd.SparseArray(np.random.randint(0, 10, size=rows)),
-             "B": pd.SparseArray(np.random.randn(rows)),
-             "C": pd.SparseArray(np.random.permutation(
+            {"A": pd.arrays.SparseArray(np.random.randint(0, 10, size=rows)),
+             "B": pd.arrays.SparseArray(np.random.randn(rows)),
+             "C": pd.arrays.SparseArray(np.random.permutation(
                  [True, False] * (rows // 2)))}
         )
-        y = pd.Series(pd.SparseArray(np.random.randn(rows)))
+        y = pd.Series(pd.arrays.SparseArray(np.random.randn(rows)))
         dtrain = xgb.DMatrix(X, y)
         booster = xgb.train({}, dtrain, num_boost_round=4)
         predt_sparse = booster.predict(xgb.DMatrix(X))
@@ -128,16 +129,17 @@ class TestPandas(unittest.TestCase):
     def test_pandas_label(self):
         # label must be a single column
         df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
-        self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
+        pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
+        self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
                           None, None, 'label', 'float')
 
         # label must be supported dtype
         df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
-        self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
+        self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
                           None, None, 'label', 'float')
 
         df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
-        result, _, _ = xgb.core._maybe_pandas_data(df, None, None,
-                                                   'label', 'float')
+        result, _, _ = pandas_handler._maybe_pandas_data(df, None, None,
+                                                         'label', 'float')
         np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
                                                        dtype=float))