Implement Python data handler. (#5689)

* Define data handlers for DMatrix.
* Throw ValueError in scikit learn interface.
This commit is contained in:
Jiaming Yuan 2020-05-22 11:53:55 +08:00 committed by GitHub
parent 646def51e0
commit 5af8161a1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 746 additions and 405 deletions

View File

@ -14,7 +14,7 @@ from sklearn.datasets import load_iris, load_digits, load_boston
rng = np.random.RandomState(31337)
print("Zeros and Ones from the Digits dataset: binary classification")
digits = load_digits(2)
digits = load_digits(n_class=2)
y = digits['target']
X = digits['data']
kf = KFold(n_splits=2, shuffle=True, random_state=rng)

View File

@ -107,7 +107,6 @@ except ImportError:
try:
from cudf import DataFrame as CUDF_DataFrame
from cudf import Series as CUDF_Series
from cudf import MultiIndex as CUDF_MultiIndex
from cudf import concat as CUDF_concat
CUDF_INSTALLED = True
except ImportError:

View File

@ -1,7 +1,6 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
# pylint: disable=too-many-branches, too-many-lines, too-many-locals
# pylint: disable=too-many-public-methods
# pylint: disable=too-many-lines, too-many-locals
"""Core XGBoost Library."""
import collections
# pylint: disable=no-name-in-module,import-error
@ -11,16 +10,15 @@ import ctypes
import os
import re
import sys
import warnings
import json
import numpy as np
import scipy.sparse
from .compat import (
STRING_TYPES, DataFrame, MultiIndex, Int64Index, py_str,
STRING_TYPES, DataFrame, py_str,
PANDAS_INSTALLED, CUDF_INSTALLED,
CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
CUDF_DataFrame,
os_fspath, os_PathLike, lazy_isinstance)
from .libpath import find_lib_path
@ -262,10 +260,24 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64':
'int', 'uint8': 'int', 'uint16': 'int', 'uint32': 'int',
'uint64': 'int', 'float16': 'float', 'float32': 'float',
'float64': 'float', 'bool': 'i'}
def _convert_unknown_data(data, meta=None, meta_type=None):
if meta is not None:
try:
data = np.array(data, dtype=meta_type)
except Exception:
raise TypeError('Can not handle data from {}'.format(
type(data).__name__))
else:
import warnings
warnings.warn(
'Unknown data type: ' + str(type(data)) +
', coverting it to csr_matrix')
try:
data = scipy.sparse.csr_matrix(data)
except Exception:
raise TypeError('Can not initialize DMatrix from'
' {}'.format(type(data).__name__))
return data
# Either object has cuda array interface or contains columns with interfaces
@ -274,57 +286,12 @@ def _has_cuda_array_interface(data):
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
def _maybe_pandas_data(data, feature_names, feature_types,
                       meta=None, meta_type=None):
    """Unpack a pandas DataFrame into a numpy array plus feature info.

    Input that is not a ``pandas.DataFrame`` (or any input when pandas is not
    installed) is returned unchanged.  For a DataFrame, column dtypes are
    validated, feature names/types are derived when they were not supplied and
    `meta` is None, and the values are cast to `meta_type` (``'float'`` when
    `meta_type` is falsy).

    Raises
    ------
    ValueError
        If a column dtype is unsupported, or if `meta` is set and the frame
        has more than one column.
    """
    is_frame = PANDAS_INSTALLED and isinstance(data, DataFrame)
    if not is_frame:
        return data, feature_names, feature_types

    from pandas.api.types import is_sparse

    column_dtypes = data.dtypes

    def _accepted(dtype):
        return dtype.name in PANDAS_DTYPE_MAPPER or is_sparse(dtype)

    if not all(_accepted(dtype) for dtype in column_dtypes):
        offending = [
            str(data.columns[pos])
            for pos, dtype in enumerate(column_dtypes)
            if dtype.name not in PANDAS_DTYPE_MAPPER
        ]
        prefix = ('DataFrame.dtypes for data must be int, float or bool.\n'
                  'Did not expect the data types in fields ')
        raise ValueError(prefix + ', '.join(offending))

    if feature_names is None and meta is None:
        if isinstance(data.columns, MultiIndex):
            # Join the levels of every multi-level column label with a space.
            feature_names = [
                ' '.join([str(level) for level in label])
                for label in data.columns
            ]
        elif isinstance(data.columns, Int64Index):
            feature_names = [str(label) for label in data.columns]
        else:
            feature_names = data.columns.format()

    if feature_types is None and meta is None:
        feature_types = [
            PANDAS_DTYPE_MAPPER[
                dtype.subtype.name if is_sparse(dtype) else dtype.name]
            for dtype in column_dtypes
        ]

    if meta and len(data.columns) > 1:
        raise ValueError(
            'DataFrame for {meta} cannot have multiple columns'.format(
                meta=meta))

    target_dtype = meta_type or 'float'
    return data.values.astype(target_dtype), feature_names, feature_types
def _cudf_array_interfaces(df):
'''Extract CuDF __cuda_array_interface__'''
interfaces = []
if lazy_isinstance(df, 'cudf.core.series', 'Series'):
interfaces.append(df.__cuda_array_interface__)
else:
for col in df:
interface = df[col].__cuda_array_interface__
if 'mask' in interface:
@ -334,124 +301,7 @@ def _cudf_array_interfaces(df):
return interfaces_str
def _maybe_cudf_dataframe(data, feature_names, feature_types):
    """Derive feature names and types from a cudf DataFrame or Series.

    Input that is not a cudf DataFrame/Series (or any input when cudf is not
    installed) is returned unchanged.  The data object itself is never
    modified; only missing `feature_names` / `feature_types` are filled in.
    """
    if not CUDF_INSTALLED or not isinstance(
            data, (CUDF_DataFrame, CUDF_Series)):
        return data, feature_names, feature_types

    is_series = isinstance(data, CUDF_Series)

    if feature_names is None:
        if is_series:
            feature_names = [data.name]
        elif isinstance(data.columns, CUDF_MultiIndex):
            feature_names = [
                ' '.join([str(level) for level in label])
                for label in data.columns
            ]
        else:
            feature_names = data.columns.format()

    if feature_types is None:
        dtypes = [data.dtype] if is_series else data.dtypes
        feature_types = [PANDAS_DTYPE_MAPPER[dtype.name] for dtype in dtypes]

    return data, feature_names, feature_types
DT_TYPE_MAPPER = {'bool': 'bool', 'int': 'int', 'real': 'float'}
DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
def _maybe_dt_data(data, feature_names, feature_types,
                   meta=None, meta_type=None):
    """Validate a datatable Frame and derive its feature names and types.

    Input that is neither ``datatable.Frame`` nor ``datatable.DataTable`` is
    returned unchanged.  When `meta` is truthy the first column is extracted
    as a numpy array of `meta_type` and ``(array, None, None)`` is returned.

    Raises
    ------
    ValueError
        If `meta` is set and the frame has more than one column, if a column
        has an unsupported logical type, or if `feature_types` was passed in.
    """
    is_datatable = (lazy_isinstance(data, 'datatable', 'Frame') or
                    lazy_isinstance(data, 'datatable', 'DataTable'))
    if not is_datatable:
        return data, feature_names, feature_types

    if meta:
        if data.shape[1] > 1:
            raise ValueError(
                'DataTable for label or weight cannot have multiple columns')
        # below requires new dt version
        # extract first column
        first_column = data.to_numpy()[:, 0].astype(meta_type)
        return first_column, None, None

    ltype_names = tuple(ltype.name for ltype in data.ltypes)
    rejected = [data.names[pos]
                for pos, ltype_name in enumerate(ltype_names)
                if ltype_name not in DT_TYPE_MAPPER]
    if rejected:
        prefix = ('DataFrame.types for data must be int, float or bool.\n'
                  'Did not expect the data types in fields ')
        raise ValueError(prefix + ', '.join(rejected))

    if feature_names is None and meta is None:
        feature_names = data.names

    # always return stypes for dt ingestion
    if feature_types is not None:
        raise ValueError(
            'DataTable has own feature types, cannot pass them in.')
    to_feature_type = np.vectorize(DT_TYPE_MAPPER2.get)
    feature_types = to_feature_type(ltype_names)

    return data, feature_names, feature_types
def _is_dlpack(x):
return 'PyCapsule' in str(type(x)) and "dltensor" in str(x)
# Just convert dlpack into cupy (zero copy)
def _maybe_dlpack_data(data, feature_names, feature_types):
    """Turn a DLPack capsule into a cupy array via ``cupy.fromDlpack``.

    Input that `_is_dlpack` does not recognise is returned unchanged.
    Feature names and types are always passed through untouched.
    """
    if _is_dlpack(data):
        from cupy import fromDlpack  # pylint: disable=E0401
        data = fromDlpack(data)
    return data, feature_names, feature_types
def _convert_dataframes(data, feature_names, feature_types,
                        meta=None, meta_type=None):
    """Run `data` through every dataframe-like converter in turn.

    The converters are applied in this order: pandas, datatable, cudf,
    dlpack.  Only the pandas and datatable converters receive `meta` and
    `meta_type`.
    """
    state = (data, feature_names, feature_types)
    for convert in (_maybe_pandas_data, _maybe_dt_data):
        state = convert(state[0], state[1], state[2], meta, meta_type)
    for convert in (_maybe_cudf_dataframe, _maybe_dlpack_data):
        state = convert(state[0], state[1], state[2])
    data, feature_names, feature_types = state
    return data, feature_names, feature_types
def _maybe_np_slice(data, dtype=np.float32):
'''Handle numpy slice. This can be removed if we use __array_interface__.
'''
try:
if not data.flags.c_contiguous:
warnings.warn(
"Use subset (sliced data) of np.ndarray is not recommended " +
"because it will generate extra copies and increase " +
"memory consumption")
data = np.array(data, copy=True, dtype=dtype)
else:
data = np.array(data, copy=False, dtype=dtype)
except AttributeError:
data = np.array(data, copy=False, dtype=dtype)
return data
class DMatrix(object):
class DMatrix: # pylint: disable=too-many-instance-attributes
"""Data Matrix used in XGBoost.
DMatrix is a internal data structure that used by XGBoost
@ -503,6 +353,13 @@ class DMatrix(object):
applicable. If -1, uses maximum threads available on the system.
"""
if isinstance(data, list):
raise TypeError('Input data can not be a list.')
self.missing = missing if missing is not None else np.nan
self.nthread = nthread if nthread is not None else 1
self.silent = silent
# force into void_p, mac need to pass things in as void_p
if data is None:
self.handle = None
@ -513,40 +370,13 @@ class DMatrix(object):
self._feature_types = feature_types
return
if isinstance(data, list):
raise TypeError('Input data can not be a list.')
data, feature_names, feature_types = _convert_dataframes(
data, feature_names, feature_types
)
missing = missing if missing is not None else np.nan
nthread = nthread if nthread is not None else 1
if isinstance(data, (STRING_TYPES, os_PathLike)):
handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
ctypes.c_int(silent),
ctypes.byref(handle)))
self.handle = handle
elif isinstance(data, scipy.sparse.csr_matrix):
self._init_from_csr(data)
elif isinstance(data, scipy.sparse.csc_matrix):
self._init_from_csc(data)
elif isinstance(data, np.ndarray):
self._init_from_npy2d(data, missing, nthread)
elif lazy_isinstance(data, 'datatable', 'Frame'):
self._init_from_dt(data, nthread)
elif hasattr(data, "__cuda_array_interface__"):
self._init_from_array_interface(data, missing, nthread)
elif CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
self._init_from_array_interface_columns(data, missing, nthread)
else:
try:
csr = scipy.sparse.csr_matrix(data)
self._init_from_csr(csr)
except Exception:
raise TypeError('can not initialize DMatrix from'
' {}'.format(type(data).__name__))
handler = self.get_data_handler(data)
if handler is None:
data = _convert_unknown_data(data, None)
handler = self.get_data_handler(data)
self.handle, feature_names, feature_types = handler.handle_input(
data, feature_names, feature_types)
assert self.handle, 'Failed to construct a DMatrix.'
if label is not None:
self.set_label(label)
@ -558,126 +388,12 @@ class DMatrix(object):
self.feature_names = feature_names
self.feature_types = feature_types
def _init_from_csr(self, csr):
    """Build the native DMatrix from a CSR matrix and store its handle.

    Raises ValueError when ``csr.indices`` and ``csr.data`` differ in length.
    """
    n_indices, n_values = len(csr.indices), len(csr.data)
    if n_indices != n_values:
        raise ValueError('length mismatch: {} vs {}'.format(
            n_indices, n_values))

    new_handle = ctypes.c_void_p()
    indptr = c_array(ctypes.c_size_t, csr.indptr)
    indices = c_array(ctypes.c_uint, csr.indices)
    values = c_array(ctypes.c_float, csr.data)
    _check_call(_LIB.XGDMatrixCreateFromCSREx(
        indptr, indices, values,
        ctypes.c_size_t(len(csr.indptr)),
        ctypes.c_size_t(len(csr.data)),
        ctypes.c_size_t(csr.shape[1]),
        ctypes.byref(new_handle)))
    self.handle = new_handle
def _init_from_csc(self, csc):
    """Build the native DMatrix from a CSC matrix and store its handle.

    Raises ValueError when ``csc.indices`` and ``csc.data`` differ in length.
    """
    n_indices, n_values = len(csc.indices), len(csc.data)
    if n_indices != n_values:
        raise ValueError('length mismatch: {} vs {}'.format(
            n_indices, n_values))

    new_handle = ctypes.c_void_p()
    indptr = c_array(ctypes.c_size_t, csc.indptr)
    indices = c_array(ctypes.c_uint, csc.indices)
    values = c_array(ctypes.c_float, csc.data)
    _check_call(_LIB.XGDMatrixCreateFromCSCEx(
        indptr, indices, values,
        ctypes.c_size_t(len(csc.indptr)),
        ctypes.c_size_t(len(csc.data)),
        ctypes.c_size_t(csc.shape[0]),
        ctypes.byref(new_handle)))
    self.handle = new_handle
def _init_from_npy2d(self, mat, missing, nthread):
    """Build the native DMatrix from a 2-D numpy array and store its handle.

    A temporary copy is made when ``mat`` is not row-major/contiguous, and
    another when it is not ``numpy.float32``, so up to two temporary copies
    may exist; keep input layout and dtype in mind if memory is a concern.

    Raises ValueError when ``mat`` is not 2-dimensional.
    """
    if len(mat.shape) != 2:
        raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
                         mat.shape)
    n_rows, n_cols = mat.shape[0], mat.shape[1]
    # Flatten by rows and ensure float32, avoiding copies where possible:
    # reshape returns a view when it can, and np.array is asked not to copy.
    flat = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
    new_handle = ctypes.c_void_p()
    flat_ptr = flat.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
    _check_call(_LIB.XGDMatrixCreateFromMat_omp(
        flat_ptr,
        c_bst_ulong(n_rows),
        c_bst_ulong(n_cols),
        ctypes.c_float(missing),
        ctypes.byref(new_handle),
        ctypes.c_int(nthread)))
    self.handle = new_handle
def _init_from_dt(self, data, nthread):
    """Build the native DMatrix from a datatable Frame and store its handle.

    One column data pointer and one storage-type name per column are
    collected and passed to ``XGDMatrixCreateFromDT``.
    """
    column_ptrs = (ctypes.c_void_p * data.ncols)()
    has_internal_api = (hasattr(data, "internal") and
                        hasattr(data.internal, "column"))
    if has_internal_api:
        # datatable>0.8.0
        for icol in range(data.ncols):
            column = data.internal.column(icol)
            column_ptrs[icol] = ctypes.c_void_p(column.data_pointer)
    else:
        # datatable<=0.8.0
        from datatable.internal import \
            frame_column_data_r  # pylint: disable=no-name-in-module,import-error
        for icol in range(data.ncols):
            column_ptrs[icol] = frame_column_data_r(data, icol)

    # always return stypes for dt ingestion
    stype_names = (ctypes.c_char_p * data.ncols)()
    for icol in range(data.ncols):
        encoded = data.stypes[icol].name.encode('utf-8')
        stype_names[icol] = ctypes.c_char_p(encoded)

    new_handle = ctypes.c_void_p()
    _check_call(_LIB.XGDMatrixCreateFromDT(
        column_ptrs, stype_names,
        c_bst_ulong(data.shape[0]),
        c_bst_ulong(data.shape[1]),
        ctypes.byref(new_handle),
        ctypes.c_int(nthread)))
    self.handle = new_handle
def _init_from_array_interface_columns(self, df, missing, nthread):
    """Build the native DMatrix from columnar data and store its handle.

    The per-column interfaces are serialised by `_cudf_array_interfaces`.
    """
    serialised = _cudf_array_interfaces(df)
    new_handle = ctypes.c_void_p()
    status = _LIB.XGDMatrixCreateFromArrayInterfaceColumns(
        serialised,
        ctypes.c_float(missing),
        ctypes.c_int(nthread),
        ctypes.byref(new_handle))
    _check_call(status)
    self.handle = new_handle
def _init_from_array_interface(self, data, missing, nthread):
    """Build the native DMatrix from an object exposing
    ``__cuda_array_interface__`` and store its handle.

    A ``mask`` entry, when present, is replaced by the mask's own
    ``__cuda_array_interface__`` before the interface is dumped to JSON.
    """
    array_interface = data.__cuda_array_interface__
    if 'mask' in array_interface:
        mask = array_interface['mask']
        array_interface['mask'] = mask.__cuda_array_interface__
    serialised = bytes(json.dumps(array_interface, indent=2), 'utf-8')
    new_handle = ctypes.c_void_p()
    status = _LIB.XGDMatrixCreateFromArrayInterface(
        serialised,
        ctypes.c_float(missing),
        ctypes.c_int(nthread),
        ctypes.byref(new_handle))
    _check_call(status)
    self.handle = new_handle
def get_data_handler(self, data, meta=None, meta_type=None):
    '''Get data handler for this DMatrix class.

    Looks up a handler for `data` through
    ``xgboost.data.get_dmatrix_data_handler``, forwarding this object's
    ``missing``, ``nthread`` and ``silent`` settings together with `meta`
    and `meta_type`, and returns whatever that lookup returns.
    '''
    # Imported here rather than at module level; the data module itself
    # imports names from this module.
    from .data import get_dmatrix_data_handler
    return get_dmatrix_data_handler(
        data, self.missing, self.nthread, self.silent, meta, meta_type)
def __del__(self):
if hasattr(self, "handle") and self.handle:
@ -737,10 +453,14 @@ class DMatrix(object):
data: numpy array
The array of data to be set
"""
data, _, _ = _convert_dataframes(data, None, None, field, 'float')
if isinstance(data, np.ndarray):
self.set_float_info_npy2d(field, data)
return
handler = self.get_data_handler(data, field, np.float32)
if handler is None:
data = _convert_unknown_data(data, field, np.float32)
handler = self.get_data_handler(data, field, np.float32)
data, _, _ = handler.transform(data)
c_data = c_array(ctypes.c_float, data)
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
@ -759,7 +479,8 @@ class DMatrix(object):
data: numpy array
The array of data to be set
"""
data = _maybe_np_slice(data, np.float32)
data, _, _ = self.get_data_handler(
data, field, np.float32).transform(data)
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
@ -777,9 +498,8 @@ class DMatrix(object):
data: numpy array
The array of data to be set
"""
data = _maybe_np_slice(data, np.uint32)
data, _, _ = _convert_dataframes(data, None, None, field, 'uint32')
data = np.array(data, copy=False, dtype=ctypes.c_uint)
data, _, _ = self.get_data_handler(
data, field, 'uint32').transform(data)
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
c_str(field),
c_array(ctypes.c_uint, data),
@ -1075,46 +795,18 @@ class DeviceQuantileDMatrix(DMatrix):
feature_types=None,
nthread=None, max_bin=256):
self.max_bin = max_bin
if not (hasattr(data, "__cuda_array_interface__") or (
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame)) or _is_dlpack(data)):
raise ValueError('Only cupy/cudf/dlpack currently supported for DeviceQuantileDMatrix')
super().__init__(data, label=label, weight=weight, base_margin=base_margin,
super().__init__(data, label=label, weight=weight,
base_margin=base_margin,
missing=missing,
silent=silent,
feature_names=feature_names,
feature_types=feature_types,
nthread=nthread)
def _init_from_array_interface_columns(self, df, missing, nthread):
    """Build a device quantile DMatrix from columnar data and store its
    handle.

    `missing` defaults to NaN and `nthread` to 1 when passed as None;
    ``self.max_bin`` is forwarded to the native call.
    """
    serialised = _cudf_array_interfaces(df)
    new_handle = ctypes.c_void_p()
    if missing is None:
        missing = np.nan
    if nthread is None:
        nthread = 1
    status = _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
        serialised,
        ctypes.c_float(missing), ctypes.c_int(nthread),
        ctypes.c_int(self.max_bin), ctypes.byref(new_handle))
    _check_call(status)
    self.handle = new_handle
def _init_from_array_interface(self, data, missing, nthread):
    """Build a device quantile DMatrix from an object exposing
    ``__cuda_array_interface__`` and store its handle.

    `missing` defaults to NaN and `nthread` to 1 when passed as None;
    ``self.max_bin`` is forwarded to the native call.
    """
    array_interface = data.__cuda_array_interface__
    if 'mask' in array_interface:
        mask = array_interface['mask']
        array_interface['mask'] = mask.__cuda_array_interface__
    serialised = bytes(json.dumps(array_interface, indent=2), 'utf-8')
    new_handle = ctypes.c_void_p()
    if missing is None:
        missing = np.nan
    if nthread is None:
        nthread = 1
    status = _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
        serialised,
        ctypes.c_float(missing), ctypes.c_int(nthread),
        ctypes.c_int(self.max_bin), ctypes.byref(new_handle))
    _check_call(status)
    self.handle = new_handle
def get_data_handler(self, data, meta=None, meta_type=None):
    '''Get data handler for this device quantile DMatrix class.

    `meta` and `meta_type` are accepted for signature compatibility but are
    not forwarded to the lookup function.
    '''
    from .data import get_device_quantile_dmatrix_data_handler
    handler = get_device_quantile_dmatrix_data_handler(
        data, self.max_bin, self.missing, self.nthread, self.silent)
    return handler
class Booster(object):
@ -1467,6 +1159,7 @@ class Booster(object):
self._validate_features(data)
return self.eval_set([(data, name)], iteration)
# pylint: disable=too-many-function-args
def predict(self,
data,
output_margin=False,

View File

@ -0,0 +1,624 @@
# pylint: disable=too-many-arguments, no-self-use
'''Data dispatching for DMatrix.'''
import ctypes
import abc
import json
import warnings
import numpy as np
from .core import c_array, _LIB, _check_call, c_str, _cudf_array_interfaces
from .compat import lazy_isinstance, STRING_TYPES, os_fspath, os_PathLike
c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
class DataHandler(abc.ABC):
    '''Base class for various data handler.

    Parameters
    ----------
    missing : float
        Value treated as missing.
    nthread : int
        Number of threads forwarded to the native library.
    silent : boolean
        Silent flag forwarded to the native library.
    meta : str, optional
        Name of the meta info field when handling meta info.
    meta_type : str/np.dtype, optional
        Type of the meta info.
    '''
    def __init__(self, missing, nthread, silent, meta=None, meta_type=None):
        self.missing = missing
        self.nthread = nthread
        self.silent = silent
        self.meta = meta
        self.meta_type = meta_type

    def _warn_unused_missing(self, data):
        '''Warn when a user supplied `missing` is ignored for `data`.

        Nothing is emitted when `missing` is None or NaN, i.e. when the user
        did not ask for a custom missing value.
        '''
        # Test None first: np.isnan(None) raises TypeError.
        if self.missing is None or np.isnan(self.missing):
            return
        warnings.warn(
            '`missing` is not used for current input data type:' +
            str(type(data)))

    def check_complex(self, data):
        '''Test whether data is complex using `dtype` attribute.

        Raises
        ------
        ValueError
            If `data` has a complex floating point dtype.
        '''
        dtype = getattr(data, 'dtype', None)
        if dtype is None:
            return
        try:
            # Covers every complex precision without naming aliases that
            # newer NumPy releases no longer provide.
            is_complex = np.issubdtype(dtype, np.complexfloating)
        except TypeError:
            # `dtype` is not something numpy understands; not complex.
            is_complex = False
        if is_complex:
            raise ValueError('Complex data not supported')

    def transform(self, data):
        '''Optional method for transforming data before being accepted by
        other XGBoost API.

        The default implementation returns ``(data, None, None)``.
        '''
        return data, None, None

    @abc.abstractmethod
    def handle_input(self, data, feature_names, feature_types):
        '''Abstract method for handling different data input.'''
class DMatrixDataManager:
    '''The registry class for various data handler.

    Handlers are stored either under an exact ``module.name`` type key or
    together with a predicate that decides whether they apply.
    '''
    def __init__(self):
        self.__data_handlers = {}
        self.__data_handlers_dly = []

    @staticmethod
    def _type_key(module, name):
        '''Build the ``module.name`` lookup key.'''
        return '.'.join([module, name])

    def register_handler(self, module, name, handler):
        '''Register a data handler handling specific type of data.'''
        self.__data_handlers[self._type_key(module, name)] = handler

    def register_handler_opaque(self, func, handler):
        '''Register a data handler that handles data with opaque type.

        Parameters
        ----------
        func : callable
            A function with a single parameter `data`.  It should return True
            if the handler can handle this data, otherwise returns False.
        handler : xgboost.data.DataHandler
            The handler class that is a subclass of `DataHandler`.
        '''
        self.__data_handlers_dly.append((func, handler))

    def get_handler(self, data):
        '''Get a handler of `data`, returns None if handler not found.

        An exact match on the type's module and name wins; otherwise the
        opaque predicates are tried in registration order.
        '''
        data_type = type(data)
        key = self._type_key(data_type.__module__, data_type.__name__)
        try:
            return self.__data_handlers[key]
        except KeyError:
            pass
        for accepts, candidate in self.__data_handlers_dly:
            if accepts(data):
                return candidate
        return None
__dmatrix_registry = DMatrixDataManager() # pylint: disable=invalid-name
def get_dmatrix_data_handler(data, missing, nthread, silent,
                             meta=None, meta_type=None):
    '''Get a handler of `data` for DMatrix.

    .. versionadded:: 1.2.0

    Parameters
    ----------
    data : any
        The input data.
    missing : float
        Same as `missing` for DMatrix.
    nthread : int
        Same as `nthread` for DMatrix.
    silent : boolean
        Same as `silent` for DMatrix.
    meta : str
        Field name of meta data, like `label`.  Used only for getting handler
        for meta info.
    meta_type : str/np.dtype
        Type of meta data.

    Returns
    -------
    handler : DataHandler
        A handler instance built from the arguments above, or None when no
        handler is registered for `data`.
    '''
    handler_cls = __dmatrix_registry.get_handler(data)
    if handler_cls is not None:
        return handler_cls(missing, nthread, silent, meta, meta_type)
    return None
class FileHandler(DataHandler):
    '''Handler of path like input.'''
    def handle_input(self, data, feature_names, feature_types):
        '''Create a DMatrix handle from the file at path `data`.

        Feature names and types are returned unchanged.
        '''
        self._warn_unused_missing(data)
        new_handle = ctypes.c_void_p()
        path = c_str(os_fspath(data))
        status = _LIB.XGDMatrixCreateFromFile(
            path, ctypes.c_int(self.silent), ctypes.byref(new_handle))
        _check_call(status)
        return new_handle, feature_names, feature_types
__dmatrix_registry.register_handler_opaque(
lambda data: isinstance(data, (STRING_TYPES, os_PathLike)),
FileHandler)
class CSRHandler(DataHandler):
    '''Handler of `scipy.sparse.csr.csr_matrix`.'''
    def handle_input(self, data, feature_names, feature_types):
        '''Initialize data from a CSR matrix.

        Raises ValueError when ``data.indices`` and ``data.data`` differ in
        length.  Feature names and types are returned unchanged.
        '''
        n_indices, n_values = len(data.indices), len(data.data)
        if n_indices != n_values:
            raise ValueError('length mismatch: {} vs {}'.format(
                n_indices, n_values))
        self._warn_unused_missing(data)
        new_handle = ctypes.c_void_p()
        indptr = c_array(ctypes.c_size_t, data.indptr)
        indices = c_array(ctypes.c_uint, data.indices)
        values = c_array(ctypes.c_float, data.data)
        _check_call(_LIB.XGDMatrixCreateFromCSREx(
            indptr, indices, values,
            ctypes.c_size_t(len(data.indptr)),
            ctypes.c_size_t(len(data.data)),
            ctypes.c_size_t(data.shape[1]),
            ctypes.byref(new_handle)))
        return new_handle, feature_names, feature_types
__dmatrix_registry.register_handler(
'scipy.sparse.csr', 'csr_matrix', CSRHandler)
class CSCHandler(DataHandler):
    '''Handler of `scipy.sparse.csc.csc_matrix`.'''
    def handle_input(self, data, feature_names, feature_types):
        '''Initialize data from a CSC matrix.

        Raises ValueError when ``data.indices`` and ``data.data`` differ in
        length.  Feature names and types are returned unchanged.
        '''
        n_indices, n_values = len(data.indices), len(data.data)
        if n_indices != n_values:
            raise ValueError('length mismatch: {} vs {}'.format(
                n_indices, n_values))
        self._warn_unused_missing(data)
        new_handle = ctypes.c_void_p()
        indptr = c_array(ctypes.c_size_t, data.indptr)
        indices = c_array(ctypes.c_uint, data.indices)
        values = c_array(ctypes.c_float, data.data)
        _check_call(_LIB.XGDMatrixCreateFromCSCEx(
            indptr, indices, values,
            ctypes.c_size_t(len(data.indptr)),
            ctypes.c_size_t(len(data.data)),
            ctypes.c_size_t(data.shape[0]),
            ctypes.byref(new_handle)))
        return new_handle, feature_names, feature_types
__dmatrix_registry.register_handler(
'scipy.sparse.csc', 'csc_matrix', CSCHandler)
class NumpyHandler(DataHandler):
    '''Handler of `numpy.ndarray`.'''
    def _maybe_np_slice(self, data, dtype):
        '''Handle numpy slice.  This can be removed if we use
        __array_interface__.

        Returns an array of `dtype`.  A non C-contiguous input triggers a
        warning and a fresh copy; otherwise a copy is made only when the
        conversion requires one.
        '''
        try:
            if not data.flags.c_contiguous:
                warnings.warn(
                    "Use subset (sliced data) of np.ndarray is not recommended " +
                    "because it will generate extra copies and increase " +
                    "memory consumption")
                data = np.array(data, copy=True, dtype=dtype)
            else:
                # np.asarray copies only when needed.  np.array(copy=False)
                # raises on NumPy >= 2 whenever a copy cannot be avoided.
                data = np.asarray(data, dtype=dtype)
        except AttributeError:
            # No `flags` attribute: not an ndarray, convert directly.
            data = np.asarray(data, dtype=dtype)
        return data

    def transform(self, data):
        '''Convert meta info to an array of ``self.meta_type``.'''
        return self._maybe_np_slice(data, self.meta_type), None, None

    def handle_input(self, data, feature_names, feature_types):
        """Initialize data from a 2-D numpy matrix.

        If ``mat`` does not have ``order='C'`` (aka row-major) or is
        not contiguous, a temporary copy will be made.

        If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will
        be made.

        So there could be as many as two temporary data copies; be mindful of
        input layout and type if memory use is a concern.

        Raises
        ------
        ValueError
            If the input is not 2-dimensional or holds complex values.
        """
        if not isinstance(data, np.ndarray) and hasattr(data, '__array__'):
            data = np.asarray(data)
        if len(data.shape) != 2:
            raise ValueError(
                'Expecting 2 dimensional numpy.ndarray, got: {}'.format(
                    data.shape))
        # Reject complex input before the float32 cast below, which would
        # otherwise discard the imaginary part first.
        self.check_complex(data)
        # flatten the array by rows and ensure it is float32.  we try to avoid
        # data copies if possible (reshape returns a view when possible and
        # np.asarray only copies when it has to)
        flatten = np.asarray(data.reshape(data.size), dtype=np.float32)
        flatten = self._maybe_np_slice(flatten, np.float32)
        handle = ctypes.c_void_p()
        _check_call(_LIB.XGDMatrixCreateFromMat_omp(
            flatten.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
            c_bst_ulong(data.shape[0]),
            c_bst_ulong(data.shape[1]),
            ctypes.c_float(self.missing),
            ctypes.byref(handle),
            ctypes.c_int(self.nthread)))
        return handle, feature_names, feature_types
__dmatrix_registry.register_handler('numpy', 'ndarray', NumpyHandler)
__dmatrix_registry.register_handler('numpy', 'matrix', NumpyHandler)
__dmatrix_registry.register_handler_opaque(
lambda x: hasattr(x, '__array__'), NumpyHandler)
class ListHandler(NumpyHandler):
    '''Handler of builtin list and tuple'''
    def handle_input(self, data, feature_names, feature_types):
        '''Convert `data` with ``np.array`` and defer to `NumpyHandler`.'''
        # NOTE(review): the assertion passes when `meta` is None, while its
        # message talks about X input being unsupported -- confirm the
        # intended condition.
        # NOTE(review): the registrations for builtin list/tuple that follow
        # this class use NumpyHandler rather than this class -- confirm.
        assert self.meta is None, 'List input data is not supported for X'
        as_array = np.array(data)
        return super().handle_input(as_array, feature_names, feature_types)
__dmatrix_registry.register_handler('builtins', 'list', NumpyHandler)
__dmatrix_registry.register_handler('builtins', 'tuple', NumpyHandler)
class PandasHandler(NumpyHandler):
    '''Handler of data structures defined by `pandas`.'''
    # Maps pandas dtype names to XGBoost feature type strings.
    pandas_dtype_mapper = {
        'int8': 'int',
        'int16': 'int',
        'int32': 'int',
        'int64': 'int',
        'uint8': 'int',
        'uint16': 'int',
        'uint32': 'int',
        'uint64': 'int',
        'float16': 'float',
        'float32': 'float',
        'float64': 'float',
        'bool': 'i'
    }

    def _maybe_pandas_data(self, data, feature_names, feature_types,
                           meta=None, meta_type=None):
        """Extract internal data from pd.DataFrame for DMatrix data

        A ``pandas.Series`` is cast to `meta_type` (``'float'`` when falsy)
        and returned with the given names/types untouched.  For a DataFrame
        the column dtypes are validated, names/types are derived when they
        were not supplied and `meta` is None, and the values are cast the
        same way.

        Raises
        ------
        ValueError
            If a column dtype is unsupported, or if `meta` is set and the
            frame has more than one column.
        """
        if lazy_isinstance(data, 'pandas.core.series', 'Series'):
            dtype = meta_type if meta_type else 'float'
            return data.values.astype(dtype), feature_names, feature_types

        from pandas.api.types import is_sparse
        from pandas import MultiIndex, Int64Index

        data_dtypes = data.dtypes

        def _supported(dtype):
            return dtype.name in self.pandas_dtype_mapper or is_sparse(dtype)

        # Use the same predicate to detect and to report bad columns, so
        # accepted sparse columns are not listed as offending.
        bad_fields = [
            str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
            if not _supported(dtype)
        ]
        if bad_fields:
            msg = ('DataFrame.dtypes for data must be int, float or bool.\n'
                   'Did not expect the data types in fields ')
            raise ValueError(msg + ', '.join(bad_fields))

        if feature_names is None and meta is None:
            if isinstance(data.columns, MultiIndex):
                feature_names = [
                    ' '.join([str(x) for x in i]) for i in data.columns
                ]
            elif isinstance(data.columns, Int64Index):
                feature_names = list(map(str, data.columns))
            else:
                feature_names = data.columns.format()

        if feature_types is None and meta is None:
            feature_types = []
            for dtype in data_dtypes:
                if is_sparse(dtype):
                    feature_types.append(self.pandas_dtype_mapper[
                        dtype.subtype.name])
                else:
                    feature_types.append(self.pandas_dtype_mapper[dtype.name])

        if meta and len(data.columns) > 1:
            raise ValueError(
                'DataFrame for {meta} cannot have multiple columns'.format(
                    meta=meta))

        dtype = meta_type if meta_type else 'float'
        data = data.values.astype(dtype)

        return data, feature_names, feature_types

    def transform(self, data):
        '''Convert pandas meta info to a numpy array.'''
        return self._maybe_pandas_data(data, None, None, self.meta,
                                       self.meta_type)

    def handle_input(self, data, feature_names, feature_types):
        '''Convert pandas input to numpy, then defer to `NumpyHandler`.'''
        data, feature_names, feature_types = self._maybe_pandas_data(
            data, feature_names, feature_types, self.meta, self.meta_type)
        return super().handle_input(data, feature_names, feature_types)
__dmatrix_registry.register_handler(
'pandas.core.frame', 'DataFrame', PandasHandler)
__dmatrix_registry.register_handler(
'pandas.core.series', 'Series', PandasHandler)
class DTHandler(DataHandler):
    '''Handler of datatable.'''
    # Logical types accepted as features.
    dt_type_mapper = {'bool': 'bool', 'int': 'int', 'real': 'float'}
    # Logical type name -> XGBoost feature type string.
    dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'}

    def _maybe_dt_data(self, data, feature_names, feature_types,
                       meta=None, meta_type=None):
        """Validate feature names and types if data table

        When `meta` is truthy the first column is returned as a numpy array
        of `meta_type` together with ``None, None``.

        Raises
        ------
        ValueError
            If `meta` is set and the frame has more than one column, if a
            column has an unsupported logical type, or if `feature_types`
            was passed in.
        """
        if meta:
            if data.shape[1] > 1:
                raise ValueError(
                    'DataTable for label or weight cannot have multiple columns')
            # below requires new dt version
            # extract first column
            first_column = data.to_numpy()[:, 0].astype(meta_type)
            return first_column, None, None

        ltype_names = tuple(ltype.name for ltype in data.ltypes)
        rejected = [data.names[pos]
                    for pos, ltype_name in enumerate(ltype_names)
                    if ltype_name not in self.dt_type_mapper]
        if rejected:
            prefix = ('DataFrame.types for data must be int, float or bool.\n'
                      'Did not expect the data types in fields ')
            raise ValueError(prefix + ', '.join(rejected))

        if feature_names is None and meta is None:
            feature_names = data.names

        # always return stypes for dt ingestion
        if feature_types is not None:
            raise ValueError(
                'DataTable has own feature types, cannot pass them in.')
        to_feature_type = np.vectorize(self.dt_type_mapper2.get)
        feature_types = to_feature_type(ltype_names)

        return data, feature_names, feature_types

    def transform(self, data):
        '''Validate/convert datatable meta info.'''
        return self._maybe_dt_data(data, None, None, self.meta, self.meta_type)

    def handle_input(self, data, feature_names, feature_types):
        '''Create a DMatrix handle from a datatable Frame.'''
        data, feature_names, feature_types = self._maybe_dt_data(
            data, feature_names, feature_types, self.meta, self.meta_type)

        column_ptrs = (ctypes.c_void_p * data.ncols)()
        has_internal_api = (hasattr(data, "internal") and
                            hasattr(data.internal, "column"))
        if has_internal_api:
            # datatable>0.8.0
            for icol in range(data.ncols):
                column = data.internal.column(icol)
                column_ptrs[icol] = ctypes.c_void_p(column.data_pointer)
        else:
            # datatable<=0.8.0
            from datatable.internal import \
                frame_column_data_r  # pylint: disable=no-name-in-module,import-error
            for icol in range(data.ncols):
                column_ptrs[icol] = frame_column_data_r(data, icol)

        # always return stypes for dt ingestion
        stype_names = (ctypes.c_char_p * data.ncols)()
        for icol in range(data.ncols):
            encoded = data.stypes[icol].name.encode('utf-8')
            stype_names[icol] = ctypes.c_char_p(encoded)

        self._warn_unused_missing(data)
        new_handle = ctypes.c_void_p()
        _check_call(_LIB.XGDMatrixCreateFromDT(
            column_ptrs, stype_names,
            c_bst_ulong(data.shape[0]),
            c_bst_ulong(data.shape[1]),
            ctypes.byref(new_handle),
            ctypes.c_int(self.nthread)))
        return new_handle, feature_names, feature_types
__dmatrix_registry.register_handler('datatable', 'Frame', DTHandler)
__dmatrix_registry.register_handler('datatable', 'DataTable', DTHandler)
class CudaArrayInterfaceHandler(DataHandler):
    '''Handler of data with `__cuda_array_interface__` (cupy.ndarray).'''
    def handle_input(self, data, feature_names, feature_types):
        """Initialize DMatrix from cupy ndarray.

        A ``mask`` entry in the interface, when present, is replaced by the
        mask's own ``__cuda_array_interface__`` before JSON serialisation.
        Feature names and types are returned unchanged.
        """
        array_interface = data.__cuda_array_interface__
        if 'mask' in array_interface:
            mask = array_interface['mask']
            array_interface['mask'] = mask.__cuda_array_interface__
        serialised = bytes(json.dumps(array_interface, indent=2), 'utf-8')

        new_handle = ctypes.c_void_p()
        status = _LIB.XGDMatrixCreateFromArrayInterface(
            serialised,
            ctypes.c_float(self.missing),
            ctypes.c_int(self.nthread),
            ctypes.byref(new_handle))
        _check_call(status)
        return new_handle, feature_names, feature_types
__dmatrix_registry.register_handler('cupy.core.core', 'ndarray',
CudaArrayInterfaceHandler)
class CudaColumnarHandler(DataHandler):
    '''Handler of CUDA based columnar data. (cudf.DataFrame)'''
    def _maybe_cudf_dataframe(self, data, feature_names, feature_types):
        """Extract internal data from cudf.DataFrame for DMatrix data.

        The data object is returned as-is; only missing `feature_names` and
        `feature_types` are filled in.
        """
        is_series = lazy_isinstance(data, 'cudf.core.series', 'Series')

        if feature_names is None:
            if is_series:
                feature_names = [data.name]
            elif lazy_isinstance(
                    data.columns, 'cudf.core.multiindex', 'MultiIndex'):
                feature_names = [
                    ' '.join([str(level) for level in label])
                    for label in data.columns
                ]
            else:
                feature_names = data.columns.format()

        if feature_types is None:
            dtypes = [data.dtype] if is_series else data.dtypes
            mapper = PandasHandler.pandas_dtype_mapper
            feature_types = [mapper[dtype.name] for dtype in dtypes]

        return data, feature_names, feature_types

    def transform(self, data):
        '''Return `data` with feature names and types derived from it.'''
        return self._maybe_cudf_dataframe(data, None, None)

    def handle_input(self, data, feature_names, feature_types):
        """Initialize DMatrix from columnar memory format."""
        data, feature_names, feature_types = self._maybe_cudf_dataframe(
            data, feature_names, feature_types)
        serialised = _cudf_array_interfaces(data)
        new_handle = ctypes.c_void_p()
        status = _LIB.XGDMatrixCreateFromArrayInterfaceColumns(
            serialised,
            ctypes.c_float(self.missing),
            ctypes.c_int(self.nthread),
            ctypes.byref(new_handle))
        _check_call(status)
        return new_handle, feature_names, feature_types
# Register CudaColumnarHandler under both
# ('cudf.core.dataframe', 'DataFrame') and ('cudf.core.series', 'Series').
__dmatrix_registry.register_handler('cudf.core.dataframe', 'DataFrame',
CudaColumnarHandler)
__dmatrix_registry.register_handler('cudf.core.series', 'Series',
CudaColumnarHandler)
class DLPackHandler(CudaArrayInterfaceHandler):
    '''Data handler for `dlpack` inputs.

    The input is converted with `cupy.fromDlpack` and then processed by the
    parent class' `handle_input`.

    '''
    def _maybe_dlpack_data(self, data, feature_names, feature_types):
        """Convert `data` with `cupy.fromDlpack`.

        `feature_names` and `feature_types` are returned untouched.
        """
        # cupy is imported lazily, only when a dlpack input is handled.
        from cupy import fromDlpack as from_dlpack  # pylint: disable=E0401
        return from_dlpack(data), feature_names, feature_types

    def transform(self, data):
        """Run `_maybe_dlpack_data` with no names or types supplied."""
        return self._maybe_dlpack_data(data, None, None)

    def handle_input(self, data, feature_names, feature_types):
        """Convert the dlpack input, then delegate to the parent handler.

        Returns
        -------
        handle, feature_names, feature_types
        """
        converted = self._maybe_dlpack_data(
            data, feature_names, feature_types)
        return super().handle_input(*converted)
# Register DLPackHandler through a predicate instead of a type name.  The
# predicate accepts objects whose type string contains 'PyCapsule' and whose
# own `str()` contains "dltensor".
# NOTE(review): presumably this is how a dlpack capsule prints -- confirm.
__dmatrix_registry.register_handler_opaque(
lambda x: 'PyCapsule' in str(type(x)) and "dltensor" in str(x),
DLPackHandler)
class DeviceQuantileDMatrixDataHandler(DataHandler):  # pylint: disable=abstract-method
    '''Common base class of the data handlers used by
    `DeviceQuantileDMatrix`.

    Stores `max_bin` on the instance and forwards the remaining constructor
    arguments to the next initialiser in the MRO.

    '''
    def __init__(self, max_bin, missing, nthread, silent,
                 meta=None, meta_type=None):
        # `max_bin` is assigned before the parent initialiser runs.
        self.max_bin = max_bin
        parent_args = (missing, nthread, silent, meta, meta_type)
        super().__init__(*parent_args)
# Handler registry dedicated to `DeviceQuantileDMatrix`; it is queried by
# `get_device_quantile_dmatrix_data_handler`.
__device_quantile_dmatrix_registry = DMatrixDataManager() # pylint: disable=invalid-name
def get_device_quantile_dmatrix_data_handler(
        data, max_bin, missing, nthread, silent):
    '''Look up and instantiate the data handler for `DeviceQuantileDMatrix`.
    Similar to `get_dmatrix_data_handler`.

    The handler class is fetched from the device quantile registry based on
    `data` and constructed as ``handler(max_bin, missing, nthread, silent)``.
    An `AssertionError` is raised when the registry returns a falsy value.

    .. versionadded:: 1.2.0

    '''
    handler_cls = __device_quantile_dmatrix_registry.get_handler(data)
    assert handler_cls, (
        'Current data type ' + str(type(data)) +
        ' is not supported for DeviceQuantileDMatrix')
    return handler_cls(max_bin, missing, nthread, silent)
class DeviceQuantileCudaArrayInterfaceHandler(
        DeviceQuantileDMatrixDataHandler):
    '''Data handler for inputs exposing `__cuda_array_interface__`, used by
    `DeviceQuantileDMatrix`.

    '''
    def handle_input(self, data, feature_names, feature_types):
        """Build a device quantile DMatrix handle from an array input.

        Inputs that lack `__cuda_array_interface__` but provide `__array__`
        are first converted with ``cupy.array(data, copy=False)``.  The
        interface dictionary is then dumped to UTF-8 encoded JSON and passed
        to `XGDeviceQuantileDMatrixCreateFromArrayInterface` along with
        `self.missing`, `self.nthread` and `self.max_bin`.

        Returns
        -------
        handle, feature_names, feature_types
        """
        has_cuda_interface = hasattr(data, '__cuda_array_interface__')
        if not has_cuda_interface and hasattr(data, '__array__'):
            import cupy         # pylint: disable=import-error
            data = cupy.array(data, copy=False)

        array_interface = data.__cuda_array_interface__
        if 'mask' in array_interface:
            # The mask entry is an array object; replace it with its own
            # interface dictionary before the structure is dumped to JSON.
            mask = array_interface['mask']
            array_interface['mask'] = mask.__cuda_array_interface__
        encoded = json.dumps(array_interface, indent=2).encode('utf-8')

        dmatrix_handle = ctypes.c_void_p()
        ret = _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
            encoded,
            ctypes.c_float(self.missing),
            ctypes.c_int(self.nthread),
            ctypes.c_int(self.max_bin),
            ctypes.byref(dmatrix_handle))
        _check_call(ret)
        return dmatrix_handle, feature_names, feature_types
# Exact-type registration for cupy arrays.
__device_quantile_dmatrix_registry.register_handler(
    'cupy.core.core', 'ndarray', DeviceQuantileCudaArrayInterfaceHandler)
# Predicate-based registrations for any other array-like input.
#
# Handlers in this registry are instantiated by
# `get_device_quantile_dmatrix_data_handler` as
# ``handler(max_bin, missing, nthread, silent)``, so every registered class
# must accept `max_bin` as its first constructor argument.  These predicates
# previously pointed at `NumpyHandler`; a handler using the plain
# ``(missing, nthread, silent, ...)`` signature would receive shifted
# arguments.  `DeviceQuantileCudaArrayInterfaceHandler` has the matching
# signature and already converts `__array__`-only inputs to cupy.
__device_quantile_dmatrix_registry.register_handler_opaque(
    lambda x: hasattr(x, '__array__'),
    DeviceQuantileCudaArrayInterfaceHandler)
__device_quantile_dmatrix_registry.register_handler_opaque(
    lambda x: hasattr(x, '__cuda_array_interface__'),
    DeviceQuantileCudaArrayInterfaceHandler)
class DeviceQuantileCudaColumnarHandler(DeviceQuantileDMatrixDataHandler,
                                        CudaColumnarHandler):
    '''Data handler for CUDA based columnar inputs, used by
    `DeviceQuantileDMatrix`.

    '''
    def __init__(self, max_bin, missing, nthread, silent,
                 meta=None, meta_type=None):
        # Same argument order as `DeviceQuantileDMatrixDataHandler.__init__`,
        # which comes next in the MRO.
        super().__init__(max_bin, missing, nthread, silent, meta, meta_type)

    def handle_input(self, data, feature_names, feature_types):
        """Build a device quantile DMatrix handle from a cudf columnar
        object.

        Feature names and types are completed via `_maybe_cudf_dataframe`,
        then the result of `_cudf_array_interfaces(data)` is handed to
        `XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns` together
        with `self.missing`, `self.nthread` and `self.max_bin`.

        Returns
        -------
        handle, feature_names, feature_types
        """
        columnar, names, types = self._maybe_cudf_dataframe(
            data, feature_names, feature_types)
        column_interfaces = _cudf_array_interfaces(columnar)

        dmatrix_handle = ctypes.c_void_p()
        ret = _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
            column_interfaces,
            ctypes.c_float(self.missing),
            ctypes.c_int(self.nthread),
            ctypes.c_int(self.max_bin),
            ctypes.byref(dmatrix_handle))
        _check_call(ret)
        return dmatrix_handle, names, types
# Register DeviceQuantileCudaColumnarHandler in the device quantile registry
# under both ('cudf.core.dataframe', 'DataFrame') and
# ('cudf.core.series', 'Series').
__device_quantile_dmatrix_registry.register_handler(
'cudf.core.dataframe', 'DataFrame', DeviceQuantileCudaColumnarHandler)
__device_quantile_dmatrix_registry.register_handler(
'cudf.core.series', 'Series', DeviceQuantileCudaColumnarHandler)
class DeviceQuantileDLPackHandler(DeviceQuantileCudaArrayInterfaceHandler,
                                  DLPackHandler):
    '''Data handler for `dlpack` inputs, used by `DeviceQuantileDMatrix`.

    '''
    def __init__(self, max_bin, missing, nthread, silent,
                 meta=None, meta_type=None):
        # Same argument order as `DeviceQuantileDMatrixDataHandler.__init__`,
        # the first initialiser found along the MRO.
        super().__init__(max_bin, missing, nthread, silent, meta, meta_type)

    def handle_input(self, data, feature_names, feature_types):
        """Convert the dlpack input via `_maybe_dlpack_data`, then delegate
        to the next `handle_input` in the MRO.

        Returns
        -------
        handle, feature_names, feature_types
        """
        converted = self._maybe_dlpack_data(
            data, feature_names, feature_types)
        return super().handle_input(*converted)
# Register DeviceQuantileDLPackHandler in the device quantile registry
# through a predicate.  The predicate accepts objects whose type string
# contains 'PyCapsule' and whose own `str()` contains "dltensor".
# NOTE(review): presumably this is how a dlpack capsule prints -- confirm.
__device_quantile_dmatrix_registry.register_handler_opaque(
lambda x: 'PyCapsule' in str(type(x)) and "dltensor" in str(x),
DeviceQuantileDLPackHandler)

View File

@ -246,7 +246,7 @@ class XGBModel(XGBModelBase):
def _more_tags(self):
'''Tags used for scikit-learn data validation.'''
return {'allow_nan': True}
return {'allow_nan': True, 'no_validation': True}
def get_booster(self):
"""Get the underlying xgboost Booster of this model.
@ -258,7 +258,8 @@ class XGBModel(XGBModelBase):
booster : a xgboost booster of underlying model
"""
if not hasattr(self, '_Booster'):
raise XGBoostError('need to call fit or load_model beforehand')
from sklearn.exceptions import NotFittedError
raise NotFittedError('need to call fit or load_model beforehand')
return self._Booster
def set_params(self, **params):
@ -332,7 +333,7 @@ class XGBModel(XGBModelBase):
for k, v in internal.items():
if k in params.keys() and params[k] is None:
params[k] = parse_parameter(v)
except XGBoostError:
except ValueError:
pass
return params
@ -536,12 +537,16 @@ class XGBModel(XGBModelBase):
else:
params.update({'eval_metric': eval_metric})
try:
self._Booster = train(params, train_dmatrix,
self.get_num_boosting_rounds(), evals=evals,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, obj=obj, feval=feval,
evals_result=evals_result,
obj=obj, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)
except XGBoostError as e:
raise ValueError(e)
if evals_result:
for val in evals_result.items():
@ -1225,6 +1230,7 @@ class XGBRanker(XGBModel):
'Custom evaluation metric is not yet supported for XGBRanker.')
params.update({'eval_metric': eval_metric})
try:
self._Booster = train(params, train_dmatrix,
self.n_estimators,
early_stopping_rounds=early_stopping_rounds,
@ -1232,6 +1238,8 @@ class XGBRanker(XGBModel):
evals_result=evals_result, feval=feval,
verbose_eval=verbose, xgb_model=xgb_model,
callbacks=callbacks)
except XGBoostError as e:
raise ValueError(e)
self.objective = params["objective"]

View File

@ -22,6 +22,16 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
dtrain = DMatrixT(X, missing=missing, label=y)
assert dtrain.num_col() == kCols
assert dtrain.num_row() == kRows
if DMatrixT is xgb.DeviceQuantileDMatrix:
# Slice is not supported by DeviceQuantileDMatrix
with pytest.raises(xgb.core.XGBoostError):
dtrain.slice(rindex=[0, 1, 2])
dtrain.slice(rindex=[0, 1, 2])
else:
dtrain.slice(rindex=[0, 1, 2])
dtrain.slice(rindex=[0, 1, 2])
return dtrain
@ -41,7 +51,7 @@ def _test_from_cupy(DMatrixT):
with pytest.raises(Exception):
X = cp.random.randn(2, 2, dtype="float32")
dtrain = DMatrixT(X, label=X)
DMatrixT(X, label=X)
def _test_cupy_training(DMatrixT):
@ -88,11 +98,14 @@ def _test_cupy_metainfo(DMatrixT):
dmat_cupy.set_interface_info('group', cupy_uints)
# Test setting info with cupy
assert np.array_equal(dmat.get_float_info('weight'), dmat_cupy.get_float_info('weight'))
assert np.array_equal(dmat.get_float_info('label'), dmat_cupy.get_float_info('label'))
assert np.array_equal(dmat.get_float_info('weight'),
dmat_cupy.get_float_info('weight'))
assert np.array_equal(dmat.get_float_info('label'),
dmat_cupy.get_float_info('label'))
assert np.array_equal(dmat.get_float_info('base_margin'),
dmat_cupy.get_float_info('base_margin'))
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
assert np.array_equal(dmat.get_uint_info('group_ptr'),
dmat_cupy.get_uint_info('group_ptr'))
class TestFromCupy:
@ -135,7 +148,9 @@ Arrow specification.'''
import cupy as cp
n = 100
X = cp.random.random((n, 2))
xgb.DeviceQuantileDMatrix(X.toDlpack())
m = xgb.DeviceQuantileDMatrix(X.toDlpack())
with pytest.raises(xgb.core.XGBoostError):
m.slice(rindex=[0, 1, 2])
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu

View File

@ -67,7 +67,8 @@ class TestPandas(unittest.TestCase):
# 0 1 1 0 0
# 1 2 0 1 0
# 2 3 0 0 1
result, _, _ = xgb.core._maybe_pandas_data(dummies, None, None)
pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
result, _, _ = pandas_handler._maybe_pandas_data(dummies, None, None)
exp = np.array([[1., 1., 0., 0.],
[2., 0., 1., 0.],
[3., 0., 0., 1.]])
@ -113,12 +114,12 @@ class TestPandas(unittest.TestCase):
import pandas as pd
rows = 100
X = pd.DataFrame(
{"A": pd.SparseArray(np.random.randint(0, 10, size=rows)),
"B": pd.SparseArray(np.random.randn(rows)),
"C": pd.SparseArray(np.random.permutation(
{"A": pd.arrays.SparseArray(np.random.randint(0, 10, size=rows)),
"B": pd.arrays.SparseArray(np.random.randn(rows)),
"C": pd.arrays.SparseArray(np.random.permutation(
[True, False] * (rows // 2)))}
)
y = pd.Series(pd.SparseArray(np.random.randn(rows)))
y = pd.Series(pd.arrays.SparseArray(np.random.randn(rows)))
dtrain = xgb.DMatrix(X, y)
booster = xgb.train({}, dtrain, num_boost_round=4)
predt_sparse = booster.predict(xgb.DMatrix(X))
@ -128,16 +129,17 @@ class TestPandas(unittest.TestCase):
def test_pandas_label(self):
# label must be a single column
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
None, None, 'label', 'float')
# label must be supported dtype
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
None, None, 'label', 'float')
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
result, _, _ = xgb.core._maybe_pandas_data(df, None, None,
result, _, _ = pandas_handler._maybe_pandas_data(df, None, None,
'label', 'float')
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
dtype=float))