Implement Python data handler. (#5689)
* Define data handlers for DMatrix. * Throw ValueError in scikit learn interface.
This commit is contained in:
parent
646def51e0
commit
5af8161a1a
@ -14,7 +14,7 @@ from sklearn.datasets import load_iris, load_digits, load_boston
|
|||||||
rng = np.random.RandomState(31337)
|
rng = np.random.RandomState(31337)
|
||||||
|
|
||||||
print("Zeros and Ones from the Digits dataset: binary classification")
|
print("Zeros and Ones from the Digits dataset: binary classification")
|
||||||
digits = load_digits(2)
|
digits = load_digits(n_class=2)
|
||||||
y = digits['target']
|
y = digits['target']
|
||||||
X = digits['data']
|
X = digits['data']
|
||||||
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
|
||||||
|
|||||||
@ -107,7 +107,6 @@ except ImportError:
|
|||||||
try:
|
try:
|
||||||
from cudf import DataFrame as CUDF_DataFrame
|
from cudf import DataFrame as CUDF_DataFrame
|
||||||
from cudf import Series as CUDF_Series
|
from cudf import Series as CUDF_Series
|
||||||
from cudf import MultiIndex as CUDF_MultiIndex
|
|
||||||
from cudf import concat as CUDF_concat
|
from cudf import concat as CUDF_concat
|
||||||
CUDF_INSTALLED = True
|
CUDF_INSTALLED = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
||||||
# pylint: disable=too-many-branches, too-many-lines, too-many-locals
|
# pylint: disable=too-many-lines, too-many-locals
|
||||||
# pylint: disable=too-many-public-methods
|
|
||||||
"""Core XGBoost Library."""
|
"""Core XGBoost Library."""
|
||||||
import collections
|
import collections
|
||||||
# pylint: disable=no-name-in-module,import-error
|
# pylint: disable=no-name-in-module,import-error
|
||||||
@ -11,16 +10,15 @@ import ctypes
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
|
|
||||||
from .compat import (
|
from .compat import (
|
||||||
STRING_TYPES, DataFrame, MultiIndex, Int64Index, py_str,
|
STRING_TYPES, DataFrame, py_str,
|
||||||
PANDAS_INSTALLED, CUDF_INSTALLED,
|
PANDAS_INSTALLED, CUDF_INSTALLED,
|
||||||
CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
|
CUDF_DataFrame,
|
||||||
os_fspath, os_PathLike, lazy_isinstance)
|
os_fspath, os_PathLike, lazy_isinstance)
|
||||||
from .libpath import find_lib_path
|
from .libpath import find_lib_path
|
||||||
|
|
||||||
@ -262,10 +260,24 @@ def c_array(ctype, values):
|
|||||||
return (ctype * len(values))(*values)
|
return (ctype * len(values))(*values)
|
||||||
|
|
||||||
|
|
||||||
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64':
|
def _convert_unknown_data(data, meta=None, meta_type=None):
|
||||||
'int', 'uint8': 'int', 'uint16': 'int', 'uint32': 'int',
|
if meta is not None:
|
||||||
'uint64': 'int', 'float16': 'float', 'float32': 'float',
|
try:
|
||||||
'float64': 'float', 'bool': 'i'}
|
data = np.array(data, dtype=meta_type)
|
||||||
|
except Exception:
|
||||||
|
raise TypeError('Can not handle data from {}'.format(
|
||||||
|
type(data).__name__))
|
||||||
|
else:
|
||||||
|
import warnings
|
||||||
|
warnings.warn(
|
||||||
|
'Unknown data type: ' + str(type(data)) +
|
||||||
|
', coverting it to csr_matrix')
|
||||||
|
try:
|
||||||
|
data = scipy.sparse.csr_matrix(data)
|
||||||
|
except Exception:
|
||||||
|
raise TypeError('Can not initialize DMatrix from'
|
||||||
|
' {}'.format(type(data).__name__))
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
# Either object has cuda array interface or contains columns with interfaces
|
# Either object has cuda array interface or contains columns with interfaces
|
||||||
@ -274,184 +286,22 @@ def _has_cuda_array_interface(data):
|
|||||||
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
|
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
|
||||||
|
|
||||||
|
|
||||||
def _maybe_pandas_data(data, feature_names, feature_types,
|
|
||||||
meta=None, meta_type=None):
|
|
||||||
"""Extract internal data from pd.DataFrame for DMatrix data"""
|
|
||||||
if not (PANDAS_INSTALLED and isinstance(data, DataFrame)):
|
|
||||||
return data, feature_names, feature_types
|
|
||||||
from pandas.api.types import is_sparse
|
|
||||||
|
|
||||||
data_dtypes = data.dtypes
|
|
||||||
if not all(dtype.name in PANDAS_DTYPE_MAPPER or is_sparse(dtype)
|
|
||||||
for dtype in data_dtypes):
|
|
||||||
bad_fields = [
|
|
||||||
str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
|
|
||||||
if dtype.name not in PANDAS_DTYPE_MAPPER
|
|
||||||
]
|
|
||||||
|
|
||||||
msg = """DataFrame.dtypes for data must be int, float or bool.
|
|
||||||
Did not expect the data types in fields """
|
|
||||||
raise ValueError(msg + ', '.join(bad_fields))
|
|
||||||
|
|
||||||
if feature_names is None and meta is None:
|
|
||||||
if isinstance(data.columns, MultiIndex):
|
|
||||||
feature_names = [
|
|
||||||
' '.join([str(x) for x in i]) for i in data.columns
|
|
||||||
]
|
|
||||||
elif isinstance(data.columns, Int64Index):
|
|
||||||
feature_names = list(map(str, data.columns))
|
|
||||||
else:
|
|
||||||
feature_names = data.columns.format()
|
|
||||||
|
|
||||||
if feature_types is None and meta is None:
|
|
||||||
feature_types = []
|
|
||||||
for dtype in data_dtypes:
|
|
||||||
if is_sparse(dtype):
|
|
||||||
feature_types.append(PANDAS_DTYPE_MAPPER[dtype.subtype.name])
|
|
||||||
else:
|
|
||||||
feature_types.append(PANDAS_DTYPE_MAPPER[dtype.name])
|
|
||||||
|
|
||||||
if meta and len(data.columns) > 1:
|
|
||||||
raise ValueError(
|
|
||||||
'DataFrame for {meta} cannot have multiple columns'.format(
|
|
||||||
meta=meta))
|
|
||||||
|
|
||||||
dtype = meta_type if meta_type else 'float'
|
|
||||||
data = data.values.astype(dtype)
|
|
||||||
|
|
||||||
return data, feature_names, feature_types
|
|
||||||
|
|
||||||
|
|
||||||
def _cudf_array_interfaces(df):
|
def _cudf_array_interfaces(df):
|
||||||
'''Extract CuDF __cuda_array_interface__'''
|
'''Extract CuDF __cuda_array_interface__'''
|
||||||
interfaces = []
|
interfaces = []
|
||||||
for col in df:
|
if lazy_isinstance(df, 'cudf.core.series', 'Series'):
|
||||||
interface = df[col].__cuda_array_interface__
|
interfaces.append(df.__cuda_array_interface__)
|
||||||
if 'mask' in interface:
|
else:
|
||||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
for col in df:
|
||||||
interfaces.append(interface)
|
interface = df[col].__cuda_array_interface__
|
||||||
|
if 'mask' in interface:
|
||||||
|
interface['mask'] = interface['mask'].__cuda_array_interface__
|
||||||
|
interfaces.append(interface)
|
||||||
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
|
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
|
||||||
return interfaces_str
|
return interfaces_str
|
||||||
|
|
||||||
|
|
||||||
def _maybe_cudf_dataframe(data, feature_names, feature_types):
|
class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||||
"""Extract internal data from cudf.DataFrame for DMatrix data."""
|
|
||||||
if not (CUDF_INSTALLED and isinstance(data,
|
|
||||||
(CUDF_DataFrame, CUDF_Series))):
|
|
||||||
return data, feature_names, feature_types
|
|
||||||
if feature_names is None:
|
|
||||||
if isinstance(data, CUDF_Series):
|
|
||||||
feature_names = [data.name]
|
|
||||||
elif isinstance(data.columns, CUDF_MultiIndex):
|
|
||||||
feature_names = [
|
|
||||||
' '.join([str(x) for x in i])
|
|
||||||
for i in data.columns
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
feature_names = data.columns.format()
|
|
||||||
if feature_types is None:
|
|
||||||
if isinstance(data, CUDF_Series):
|
|
||||||
dtypes = [data.dtype]
|
|
||||||
else:
|
|
||||||
dtypes = data.dtypes
|
|
||||||
feature_types = [PANDAS_DTYPE_MAPPER[d.name] for d in dtypes]
|
|
||||||
return data, feature_names, feature_types
|
|
||||||
|
|
||||||
|
|
||||||
DT_TYPE_MAPPER = {'bool': 'bool', 'int': 'int', 'real': 'float'}
|
|
||||||
|
|
||||||
DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_dt_data(data, feature_names, feature_types,
|
|
||||||
meta=None, meta_type=None):
|
|
||||||
"""Validate feature names and types if data table"""
|
|
||||||
if (not lazy_isinstance(data, 'datatable', 'Frame') and
|
|
||||||
not lazy_isinstance(data, 'datatable', 'DataTable')):
|
|
||||||
return data, feature_names, feature_types
|
|
||||||
|
|
||||||
if meta and data.shape[1] > 1:
|
|
||||||
raise ValueError(
|
|
||||||
'DataTable for label or weight cannot have multiple columns')
|
|
||||||
if meta:
|
|
||||||
# below requires new dt version
|
|
||||||
# extract first column
|
|
||||||
data = data.to_numpy()[:, 0].astype(meta_type)
|
|
||||||
return data, None, None
|
|
||||||
|
|
||||||
data_types_names = tuple(lt.name for lt in data.ltypes)
|
|
||||||
bad_fields = [data.names[i]
|
|
||||||
for i, type_name in enumerate(data_types_names)
|
|
||||||
if type_name not in DT_TYPE_MAPPER]
|
|
||||||
if bad_fields:
|
|
||||||
msg = """DataFrame.types for data must be int, float or bool.
|
|
||||||
Did not expect the data types in fields """
|
|
||||||
raise ValueError(msg + ', '.join(bad_fields))
|
|
||||||
|
|
||||||
if feature_names is None and meta is None:
|
|
||||||
feature_names = data.names
|
|
||||||
|
|
||||||
# always return stypes for dt ingestion
|
|
||||||
if feature_types is not None:
|
|
||||||
raise ValueError(
|
|
||||||
'DataTable has own feature types, cannot pass them in.')
|
|
||||||
feature_types = np.vectorize(DT_TYPE_MAPPER2.get)(data_types_names)
|
|
||||||
|
|
||||||
return data, feature_names, feature_types
|
|
||||||
|
|
||||||
def _is_dlpack(x):
|
|
||||||
return 'PyCapsule' in str(type(x)) and "dltensor" in str(x)
|
|
||||||
|
|
||||||
# Just convert dlpack into cupy (zero copy)
|
|
||||||
def _maybe_dlpack_data(data, feature_names, feature_types):
|
|
||||||
if not _is_dlpack(data):
|
|
||||||
return data, feature_names, feature_types
|
|
||||||
from cupy import fromDlpack # pylint: disable=E0401
|
|
||||||
data = fromDlpack(data)
|
|
||||||
return data, feature_names, feature_types
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_dataframes(data, feature_names, feature_types,
|
|
||||||
meta=None, meta_type=None):
|
|
||||||
data, feature_names, feature_types = _maybe_pandas_data(data,
|
|
||||||
feature_names,
|
|
||||||
feature_types,
|
|
||||||
meta,
|
|
||||||
meta_type)
|
|
||||||
|
|
||||||
data, feature_names, feature_types = _maybe_dt_data(data,
|
|
||||||
feature_names,
|
|
||||||
feature_types,
|
|
||||||
meta,
|
|
||||||
meta_type)
|
|
||||||
|
|
||||||
data, feature_names, feature_types = _maybe_cudf_dataframe(
|
|
||||||
data, feature_names, feature_types)
|
|
||||||
|
|
||||||
data, feature_names, feature_types = _maybe_dlpack_data(
|
|
||||||
data, feature_names, feature_types)
|
|
||||||
|
|
||||||
return data, feature_names, feature_types
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_np_slice(data, dtype=np.float32):
|
|
||||||
'''Handle numpy slice. This can be removed if we use __array_interface__.
|
|
||||||
'''
|
|
||||||
try:
|
|
||||||
if not data.flags.c_contiguous:
|
|
||||||
warnings.warn(
|
|
||||||
"Use subset (sliced data) of np.ndarray is not recommended " +
|
|
||||||
"because it will generate extra copies and increase " +
|
|
||||||
"memory consumption")
|
|
||||||
data = np.array(data, copy=True, dtype=dtype)
|
|
||||||
else:
|
|
||||||
data = np.array(data, copy=False, dtype=dtype)
|
|
||||||
except AttributeError:
|
|
||||||
data = np.array(data, copy=False, dtype=dtype)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
class DMatrix(object):
|
|
||||||
"""Data Matrix used in XGBoost.
|
"""Data Matrix used in XGBoost.
|
||||||
|
|
||||||
DMatrix is a internal data structure that used by XGBoost
|
DMatrix is a internal data structure that used by XGBoost
|
||||||
@ -503,6 +353,13 @@ class DMatrix(object):
|
|||||||
applicable. If -1, uses maximum threads available on the system.
|
applicable. If -1, uses maximum threads available on the system.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if isinstance(data, list):
|
||||||
|
raise TypeError('Input data can not be a list.')
|
||||||
|
|
||||||
|
self.missing = missing if missing is not None else np.nan
|
||||||
|
self.nthread = nthread if nthread is not None else 1
|
||||||
|
self.silent = silent
|
||||||
|
|
||||||
# force into void_p, mac need to pass things in as void_p
|
# force into void_p, mac need to pass things in as void_p
|
||||||
if data is None:
|
if data is None:
|
||||||
self.handle = None
|
self.handle = None
|
||||||
@ -513,40 +370,13 @@ class DMatrix(object):
|
|||||||
self._feature_types = feature_types
|
self._feature_types = feature_types
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(data, list):
|
handler = self.get_data_handler(data)
|
||||||
raise TypeError('Input data can not be a list.')
|
if handler is None:
|
||||||
|
data = _convert_unknown_data(data, None)
|
||||||
data, feature_names, feature_types = _convert_dataframes(
|
handler = self.get_data_handler(data)
|
||||||
data, feature_names, feature_types
|
self.handle, feature_names, feature_types = handler.handle_input(
|
||||||
)
|
data, feature_names, feature_types)
|
||||||
missing = missing if missing is not None else np.nan
|
assert self.handle, 'Failed to construct a DMatrix.'
|
||||||
nthread = nthread if nthread is not None else 1
|
|
||||||
|
|
||||||
if isinstance(data, (STRING_TYPES, os_PathLike)):
|
|
||||||
handle = ctypes.c_void_p()
|
|
||||||
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
|
|
||||||
ctypes.c_int(silent),
|
|
||||||
ctypes.byref(handle)))
|
|
||||||
self.handle = handle
|
|
||||||
elif isinstance(data, scipy.sparse.csr_matrix):
|
|
||||||
self._init_from_csr(data)
|
|
||||||
elif isinstance(data, scipy.sparse.csc_matrix):
|
|
||||||
self._init_from_csc(data)
|
|
||||||
elif isinstance(data, np.ndarray):
|
|
||||||
self._init_from_npy2d(data, missing, nthread)
|
|
||||||
elif lazy_isinstance(data, 'datatable', 'Frame'):
|
|
||||||
self._init_from_dt(data, nthread)
|
|
||||||
elif hasattr(data, "__cuda_array_interface__"):
|
|
||||||
self._init_from_array_interface(data, missing, nthread)
|
|
||||||
elif CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
|
|
||||||
self._init_from_array_interface_columns(data, missing, nthread)
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
csr = scipy.sparse.csr_matrix(data)
|
|
||||||
self._init_from_csr(csr)
|
|
||||||
except Exception:
|
|
||||||
raise TypeError('can not initialize DMatrix from'
|
|
||||||
' {}'.format(type(data).__name__))
|
|
||||||
|
|
||||||
if label is not None:
|
if label is not None:
|
||||||
self.set_label(label)
|
self.set_label(label)
|
||||||
@ -558,126 +388,12 @@ class DMatrix(object):
|
|||||||
self.feature_names = feature_names
|
self.feature_names = feature_names
|
||||||
self.feature_types = feature_types
|
self.feature_types = feature_types
|
||||||
|
|
||||||
def _init_from_csr(self, csr):
|
def get_data_handler(self, data, meta=None, meta_type=None):
|
||||||
"""Initialize data from a CSR matrix."""
|
'''Get data handler for this DMatrix class.'''
|
||||||
if len(csr.indices) != len(csr.data):
|
from .data import get_dmatrix_data_handler
|
||||||
raise ValueError('length mismatch: {} vs {}'.format(
|
handler = get_dmatrix_data_handler(
|
||||||
len(csr.indices), len(csr.data)))
|
data, self.missing, self.nthread, self.silent, meta, meta_type)
|
||||||
handle = ctypes.c_void_p()
|
return handler
|
||||||
_check_call(_LIB.XGDMatrixCreateFromCSREx(
|
|
||||||
c_array(ctypes.c_size_t, csr.indptr),
|
|
||||||
c_array(ctypes.c_uint, csr.indices),
|
|
||||||
c_array(ctypes.c_float, csr.data),
|
|
||||||
ctypes.c_size_t(len(csr.indptr)),
|
|
||||||
ctypes.c_size_t(len(csr.data)),
|
|
||||||
ctypes.c_size_t(csr.shape[1]),
|
|
||||||
ctypes.byref(handle)))
|
|
||||||
self.handle = handle
|
|
||||||
|
|
||||||
def _init_from_csc(self, csc):
|
|
||||||
"""Initialize data from a CSC matrix."""
|
|
||||||
if len(csc.indices) != len(csc.data):
|
|
||||||
raise ValueError('length mismatch: {} vs {}'.format(
|
|
||||||
len(csc.indices), len(csc.data)))
|
|
||||||
handle = ctypes.c_void_p()
|
|
||||||
_check_call(_LIB.XGDMatrixCreateFromCSCEx(
|
|
||||||
c_array(ctypes.c_size_t, csc.indptr),
|
|
||||||
c_array(ctypes.c_uint, csc.indices),
|
|
||||||
c_array(ctypes.c_float, csc.data),
|
|
||||||
ctypes.c_size_t(len(csc.indptr)),
|
|
||||||
ctypes.c_size_t(len(csc.data)),
|
|
||||||
ctypes.c_size_t(csc.shape[0]),
|
|
||||||
ctypes.byref(handle)))
|
|
||||||
self.handle = handle
|
|
||||||
|
|
||||||
def _init_from_npy2d(self, mat, missing, nthread):
|
|
||||||
"""Initialize data from a 2-D numpy matrix.
|
|
||||||
|
|
||||||
If ``mat`` does not have ``order='C'`` (aka row-major) or is
|
|
||||||
not contiguous, a temporary copy will be made.
|
|
||||||
|
|
||||||
If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will
|
|
||||||
be made.
|
|
||||||
|
|
||||||
So there could be as many as two temporary data copies; be mindful of
|
|
||||||
input layout and type if memory use is a concern.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if len(mat.shape) != 2:
|
|
||||||
raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
|
|
||||||
mat.shape)
|
|
||||||
# flatten the array by rows and ensure it is float32. we try to avoid
|
|
||||||
# data copies if possible (reshape returns a view when possible and we
|
|
||||||
# explicitly tell np.array to try and avoid copying)
|
|
||||||
data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32)
|
|
||||||
handle = ctypes.c_void_p()
|
|
||||||
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
|
|
||||||
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
|
||||||
c_bst_ulong(mat.shape[0]),
|
|
||||||
c_bst_ulong(mat.shape[1]),
|
|
||||||
ctypes.c_float(missing),
|
|
||||||
ctypes.byref(handle),
|
|
||||||
ctypes.c_int(nthread)))
|
|
||||||
self.handle = handle
|
|
||||||
|
|
||||||
def _init_from_dt(self, data, nthread):
|
|
||||||
"""Initialize data from a datatable Frame."""
|
|
||||||
ptrs = (ctypes.c_void_p * data.ncols)()
|
|
||||||
if hasattr(data, "internal") and hasattr(data.internal, "column"):
|
|
||||||
# datatable>0.8.0
|
|
||||||
for icol in range(data.ncols):
|
|
||||||
col = data.internal.column(icol)
|
|
||||||
ptr = col.data_pointer
|
|
||||||
ptrs[icol] = ctypes.c_void_p(ptr)
|
|
||||||
else:
|
|
||||||
# datatable<=0.8.0
|
|
||||||
from datatable.internal import \
|
|
||||||
frame_column_data_r # pylint: disable=no-name-in-module,import-error
|
|
||||||
for icol in range(data.ncols):
|
|
||||||
ptrs[icol] = frame_column_data_r(data, icol)
|
|
||||||
|
|
||||||
# always return stypes for dt ingestion
|
|
||||||
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
|
||||||
for icol in range(data.ncols):
|
|
||||||
feature_type_strings[icol] = ctypes.c_char_p(
|
|
||||||
data.stypes[icol].name.encode('utf-8'))
|
|
||||||
|
|
||||||
handle = ctypes.c_void_p()
|
|
||||||
_check_call(_LIB.XGDMatrixCreateFromDT(
|
|
||||||
ptrs, feature_type_strings,
|
|
||||||
c_bst_ulong(data.shape[0]),
|
|
||||||
c_bst_ulong(data.shape[1]),
|
|
||||||
ctypes.byref(handle),
|
|
||||||
ctypes.c_int(nthread)))
|
|
||||||
self.handle = handle
|
|
||||||
|
|
||||||
def _init_from_array_interface_columns(self, df, missing, nthread):
|
|
||||||
"""Initialize DMatrix from columnar memory format."""
|
|
||||||
interfaces_str = _cudf_array_interfaces(df)
|
|
||||||
handle = ctypes.c_void_p()
|
|
||||||
_check_call(
|
|
||||||
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
|
|
||||||
interfaces_str,
|
|
||||||
ctypes.c_float(missing),
|
|
||||||
ctypes.c_int(nthread),
|
|
||||||
ctypes.byref(handle)))
|
|
||||||
self.handle = handle
|
|
||||||
|
|
||||||
def _init_from_array_interface(self, data, missing, nthread):
|
|
||||||
"""Initialize DMatrix from cupy ndarray."""
|
|
||||||
interface = data.__cuda_array_interface__
|
|
||||||
if 'mask' in interface:
|
|
||||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
|
||||||
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
|
|
||||||
|
|
||||||
handle = ctypes.c_void_p()
|
|
||||||
_check_call(
|
|
||||||
_LIB.XGDMatrixCreateFromArrayInterface(
|
|
||||||
interface_str,
|
|
||||||
ctypes.c_float(missing),
|
|
||||||
ctypes.c_int(nthread),
|
|
||||||
ctypes.byref(handle)))
|
|
||||||
self.handle = handle
|
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if hasattr(self, "handle") and self.handle:
|
if hasattr(self, "handle") and self.handle:
|
||||||
@ -737,10 +453,14 @@ class DMatrix(object):
|
|||||||
data: numpy array
|
data: numpy array
|
||||||
The array of data to be set
|
The array of data to be set
|
||||||
"""
|
"""
|
||||||
data, _, _ = _convert_dataframes(data, None, None, field, 'float')
|
|
||||||
if isinstance(data, np.ndarray):
|
if isinstance(data, np.ndarray):
|
||||||
self.set_float_info_npy2d(field, data)
|
self.set_float_info_npy2d(field, data)
|
||||||
return
|
return
|
||||||
|
handler = self.get_data_handler(data, field, np.float32)
|
||||||
|
if handler is None:
|
||||||
|
data = _convert_unknown_data(data, field, np.float32)
|
||||||
|
handler = self.get_data_handler(data, field, np.float32)
|
||||||
|
data, _, _ = handler.transform(data)
|
||||||
c_data = c_array(ctypes.c_float, data)
|
c_data = c_array(ctypes.c_float, data)
|
||||||
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
||||||
c_str(field),
|
c_str(field),
|
||||||
@ -759,7 +479,8 @@ class DMatrix(object):
|
|||||||
data: numpy array
|
data: numpy array
|
||||||
The array of data to be set
|
The array of data to be set
|
||||||
"""
|
"""
|
||||||
data = _maybe_np_slice(data, np.float32)
|
data, _, _ = self.get_data_handler(
|
||||||
|
data, field, np.float32).transform(data)
|
||||||
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
||||||
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
||||||
c_str(field),
|
c_str(field),
|
||||||
@ -777,9 +498,8 @@ class DMatrix(object):
|
|||||||
data: numpy array
|
data: numpy array
|
||||||
The array of data to be set
|
The array of data to be set
|
||||||
"""
|
"""
|
||||||
data = _maybe_np_slice(data, np.uint32)
|
data, _, _ = self.get_data_handler(
|
||||||
data, _, _ = _convert_dataframes(data, None, None, field, 'uint32')
|
data, field, 'uint32').transform(data)
|
||||||
data = np.array(data, copy=False, dtype=ctypes.c_uint)
|
|
||||||
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
|
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
|
||||||
c_str(field),
|
c_str(field),
|
||||||
c_array(ctypes.c_uint, data),
|
c_array(ctypes.c_uint, data),
|
||||||
@ -1075,46 +795,18 @@ class DeviceQuantileDMatrix(DMatrix):
|
|||||||
feature_types=None,
|
feature_types=None,
|
||||||
nthread=None, max_bin=256):
|
nthread=None, max_bin=256):
|
||||||
self.max_bin = max_bin
|
self.max_bin = max_bin
|
||||||
if not (hasattr(data, "__cuda_array_interface__") or (
|
super().__init__(data, label=label, weight=weight,
|
||||||
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame)) or _is_dlpack(data)):
|
base_margin=base_margin,
|
||||||
raise ValueError('Only cupy/cudf/dlpack currently supported for DeviceQuantileDMatrix')
|
|
||||||
|
|
||||||
super().__init__(data, label=label, weight=weight, base_margin=base_margin,
|
|
||||||
missing=missing,
|
missing=missing,
|
||||||
silent=silent,
|
silent=silent,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
nthread=nthread)
|
nthread=nthread)
|
||||||
|
|
||||||
def _init_from_array_interface_columns(self, df, missing, nthread):
|
def get_data_handler(self, data, meta=None, meta_type=None):
|
||||||
"""Initialize DMatrix from columnar memory format."""
|
from .data import get_device_quantile_dmatrix_data_handler
|
||||||
interfaces_str = _cudf_array_interfaces(df)
|
return get_device_quantile_dmatrix_data_handler(
|
||||||
handle = ctypes.c_void_p()
|
data, self.max_bin, self.missing, self.nthread, self.silent)
|
||||||
missing = missing if missing is not None else np.nan
|
|
||||||
nthread = nthread if nthread is not None else 1
|
|
||||||
_check_call(
|
|
||||||
_LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
|
|
||||||
interfaces_str,
|
|
||||||
ctypes.c_float(missing), ctypes.c_int(nthread),
|
|
||||||
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
|
|
||||||
self.handle = handle
|
|
||||||
|
|
||||||
def _init_from_array_interface(self, data, missing, nthread):
|
|
||||||
"""Initialize DMatrix from cupy ndarray."""
|
|
||||||
interface = data.__cuda_array_interface__
|
|
||||||
if 'mask' in interface:
|
|
||||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
|
||||||
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
|
|
||||||
|
|
||||||
handle = ctypes.c_void_p()
|
|
||||||
missing = missing if missing is not None else np.nan
|
|
||||||
nthread = nthread if nthread is not None else 1
|
|
||||||
_check_call(
|
|
||||||
_LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
|
|
||||||
interface_str,
|
|
||||||
ctypes.c_float(missing), ctypes.c_int(nthread),
|
|
||||||
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
|
|
||||||
self.handle = handle
|
|
||||||
|
|
||||||
|
|
||||||
class Booster(object):
|
class Booster(object):
|
||||||
@ -1467,6 +1159,7 @@ class Booster(object):
|
|||||||
self._validate_features(data)
|
self._validate_features(data)
|
||||||
return self.eval_set([(data, name)], iteration)
|
return self.eval_set([(data, name)], iteration)
|
||||||
|
|
||||||
|
# pylint: disable=too-many-function-args
|
||||||
def predict(self,
|
def predict(self,
|
||||||
data,
|
data,
|
||||||
output_margin=False,
|
output_margin=False,
|
||||||
|
|||||||
624
python-package/xgboost/data.py
Normal file
624
python-package/xgboost/data.py
Normal file
@ -0,0 +1,624 @@
|
|||||||
|
# pylint: disable=too-many-arguments, no-self-use
|
||||||
|
'''Data dispatching for DMatrix.'''
|
||||||
|
import ctypes
|
||||||
|
import abc
|
||||||
|
import json
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .core import c_array, _LIB, _check_call, c_str, _cudf_array_interfaces
|
||||||
|
from .compat import lazy_isinstance, STRING_TYPES, os_fspath, os_PathLike
|
||||||
|
|
||||||
|
c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class DataHandler(abc.ABC):
|
||||||
|
'''Base class for various data handler.'''
|
||||||
|
def __init__(self, missing, nthread, silent, meta=None, meta_type=None):
|
||||||
|
self.missing = missing
|
||||||
|
self.nthread = nthread
|
||||||
|
self.silent = silent
|
||||||
|
|
||||||
|
self.meta = meta
|
||||||
|
self.meta_type = meta_type
|
||||||
|
|
||||||
|
def _warn_unused_missing(self, data):
|
||||||
|
if not (np.isnan(np.nan) or None):
|
||||||
|
warnings.warn(
|
||||||
|
'`missing` is not used for current input data type:' +
|
||||||
|
str(type(data)))
|
||||||
|
|
||||||
|
def check_complex(self, data):
|
||||||
|
'''Test whether data is complex using `dtype` attribute.'''
|
||||||
|
complex_dtypes = (np.complex128, np.complex64,
|
||||||
|
np.cfloat, np.cdouble, np.clongdouble)
|
||||||
|
if hasattr(data, 'dtype') and data.dtype in complex_dtypes:
|
||||||
|
raise ValueError('Complex data not supported')
|
||||||
|
|
||||||
|
def transform(self, data):
|
||||||
|
'''Optional method for transforming data before being accepted by
|
||||||
|
other XGBoost API.'''
|
||||||
|
return data, None, None
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
'''Abstract method for handling different data input.'''
|
||||||
|
|
||||||
|
|
||||||
|
class DMatrixDataManager:
|
||||||
|
'''The registry class for various data handler.'''
|
||||||
|
def __init__(self):
|
||||||
|
self.__data_handlers = {}
|
||||||
|
self.__data_handlers_dly = []
|
||||||
|
|
||||||
|
def register_handler(self, module, name, handler):
|
||||||
|
'''Register a data handler handling specfic type of data.'''
|
||||||
|
self.__data_handlers['.'.join([module, name])] = handler
|
||||||
|
|
||||||
|
def register_handler_opaque(self, func, handler):
|
||||||
|
'''Register a data handler that handles data with opaque type.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
func : callable
|
||||||
|
A function with a single parameter `data`. It should return True
|
||||||
|
if the handler can handle this data, otherwise returns False.
|
||||||
|
handler : xgboost.data.DataHandler
|
||||||
|
The handler class that is a subclass of `DataHandler`.
|
||||||
|
'''
|
||||||
|
self.__data_handlers_dly.append((func, handler))
|
||||||
|
|
||||||
|
def get_handler(self, data):
|
||||||
|
'''Get a handler of `data`, returns None if handler not found.'''
|
||||||
|
module, name = type(data).__module__, type(data).__name__
|
||||||
|
if '.'.join([module, name]) in self.__data_handlers.keys():
|
||||||
|
handler = self.__data_handlers['.'.join([module, name])]
|
||||||
|
return handler
|
||||||
|
for f, handler in self.__data_handlers_dly:
|
||||||
|
if f(data):
|
||||||
|
return handler
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry = DMatrixDataManager() # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def get_dmatrix_data_handler(data, missing, nthread, silent,
|
||||||
|
meta=None, meta_type=None):
|
||||||
|
'''Get a handler of `data` for DMatrix.
|
||||||
|
|
||||||
|
.. versionadded:: 1.2.0
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : any
|
||||||
|
The input data.
|
||||||
|
missing : float
|
||||||
|
Same as `missing` for DMatrix.
|
||||||
|
nthread : int
|
||||||
|
Same as `nthread` for DMatrix.
|
||||||
|
silent : boolean
|
||||||
|
Same as `silent` for DMatrix.
|
||||||
|
meta : str
|
||||||
|
Field name of meta data, like `label`. Used only for getting handler
|
||||||
|
for meta info.
|
||||||
|
meta_type : str/np.dtype
|
||||||
|
Type of meta data.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
handler : DataHandler
|
||||||
|
'''
|
||||||
|
handler = __dmatrix_registry.get_handler(data)
|
||||||
|
if handler is None:
|
||||||
|
return None
|
||||||
|
return handler(missing, nthread, silent, meta, meta_type)
|
||||||
|
|
||||||
|
|
||||||
|
class FileHandler(DataHandler):
|
||||||
|
'''Handler of path like input.'''
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
self._warn_unused_missing(data)
|
||||||
|
handle = ctypes.c_void_p()
|
||||||
|
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
|
||||||
|
ctypes.c_int(self.silent),
|
||||||
|
ctypes.byref(handle)))
|
||||||
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry.register_handler_opaque(
|
||||||
|
lambda data: isinstance(data, (STRING_TYPES, os_PathLike)),
|
||||||
|
FileHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class CSRHandler(DataHandler):
|
||||||
|
'''Handler of `scipy.sparse.csr.csr_matrix`.'''
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
'''Initialize data from a CSR matrix.'''
|
||||||
|
if len(data.indices) != len(data.data):
|
||||||
|
raise ValueError('length mismatch: {} vs {}'.format(
|
||||||
|
len(data.indices), len(data.data)))
|
||||||
|
self._warn_unused_missing(data)
|
||||||
|
handle = ctypes.c_void_p()
|
||||||
|
_check_call(_LIB.XGDMatrixCreateFromCSREx(
|
||||||
|
c_array(ctypes.c_size_t, data.indptr),
|
||||||
|
c_array(ctypes.c_uint, data.indices),
|
||||||
|
c_array(ctypes.c_float, data.data),
|
||||||
|
ctypes.c_size_t(len(data.indptr)),
|
||||||
|
ctypes.c_size_t(len(data.data)),
|
||||||
|
ctypes.c_size_t(data.shape[1]),
|
||||||
|
ctypes.byref(handle)))
|
||||||
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry.register_handler(
|
||||||
|
'scipy.sparse.csr', 'csr_matrix', CSRHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class CSCHandler(DataHandler):
|
||||||
|
'''Handler of `scipy.sparse.csc.csc_matrix`.'''
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
if len(data.indices) != len(data.data):
|
||||||
|
raise ValueError('length mismatch: {} vs {}'.format(
|
||||||
|
len(data.indices), len(data.data)))
|
||||||
|
self._warn_unused_missing(data)
|
||||||
|
handle = ctypes.c_void_p()
|
||||||
|
_check_call(_LIB.XGDMatrixCreateFromCSCEx(
|
||||||
|
c_array(ctypes.c_size_t, data.indptr),
|
||||||
|
c_array(ctypes.c_uint, data.indices),
|
||||||
|
c_array(ctypes.c_float, data.data),
|
||||||
|
ctypes.c_size_t(len(data.indptr)),
|
||||||
|
ctypes.c_size_t(len(data.data)),
|
||||||
|
ctypes.c_size_t(data.shape[0]),
|
||||||
|
ctypes.byref(handle)))
|
||||||
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry.register_handler(
|
||||||
|
'scipy.sparse.csc', 'csc_matrix', CSCHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyHandler(DataHandler):
|
||||||
|
'''Handler of `numpy.ndarray`.'''
|
||||||
|
def _maybe_np_slice(self, data, dtype):
|
||||||
|
'''Handle numpy slice. This can be removed if we use __array_interface__.
|
||||||
|
'''
|
||||||
|
try:
|
||||||
|
if not data.flags.c_contiguous:
|
||||||
|
warnings.warn(
|
||||||
|
"Use subset (sliced data) of np.ndarray is not recommended " +
|
||||||
|
"because it will generate extra copies and increase " +
|
||||||
|
"memory consumption")
|
||||||
|
data = np.array(data, copy=True, dtype=dtype)
|
||||||
|
else:
|
||||||
|
data = np.array(data, copy=False, dtype=dtype)
|
||||||
|
except AttributeError:
|
||||||
|
data = np.array(data, copy=False, dtype=dtype)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def transform(self, data):
|
||||||
|
return self._maybe_np_slice(data, self.meta_type), None, None
|
||||||
|
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
"""Initialize data from a 2-D numpy matrix.
|
||||||
|
|
||||||
|
If ``mat`` does not have ``order='C'`` (aka row-major) or is
|
||||||
|
not contiguous, a temporary copy will be made.
|
||||||
|
|
||||||
|
If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will
|
||||||
|
be made.
|
||||||
|
|
||||||
|
So there could be as many as two temporary data copies; be mindful of
|
||||||
|
input layout and type if memory use is a concern.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not isinstance(data, np.ndarray) and hasattr(data, '__array__'):
|
||||||
|
data = np.array(data, copy=False)
|
||||||
|
if len(data.shape) != 2:
|
||||||
|
raise ValueError('Expecting 2 dimensional numpy.ndarray, got: ',
|
||||||
|
data.shape)
|
||||||
|
# flatten the array by rows and ensure it is float32. we try to avoid
|
||||||
|
# data copies if possible (reshape returns a view when possible and we
|
||||||
|
# explicitly tell np.array to try and avoid copying)
|
||||||
|
flatten = np.array(data.reshape(data.size), copy=False,
|
||||||
|
dtype=np.float32)
|
||||||
|
flatten = self._maybe_np_slice(flatten, np.float32)
|
||||||
|
self.check_complex(data)
|
||||||
|
handle = ctypes.c_void_p()
|
||||||
|
_check_call(_LIB.XGDMatrixCreateFromMat_omp(
|
||||||
|
flatten.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
|
||||||
|
c_bst_ulong(data.shape[0]),
|
||||||
|
c_bst_ulong(data.shape[1]),
|
||||||
|
ctypes.c_float(self.missing),
|
||||||
|
ctypes.byref(handle),
|
||||||
|
ctypes.c_int(self.nthread)))
|
||||||
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry.register_handler('numpy', 'ndarray', NumpyHandler)
|
||||||
|
__dmatrix_registry.register_handler('numpy', 'matrix', NumpyHandler)
|
||||||
|
__dmatrix_registry.register_handler_opaque(
|
||||||
|
lambda x: hasattr(x, '__array__'), NumpyHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class ListHandler(NumpyHandler):
|
||||||
|
'''Handler of builtin list and tuple'''
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
assert self.meta is None, 'List input data is not supported for X'
|
||||||
|
data = np.array(data)
|
||||||
|
return super().handle_input(data, feature_names, feature_types)
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry.register_handler('builtins', 'list', NumpyHandler)
|
||||||
|
__dmatrix_registry.register_handler('builtins', 'tuple', NumpyHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class PandasHandler(NumpyHandler):
|
||||||
|
'''Handler of data structures defined by `pandas`.'''
|
||||||
|
pandas_dtype_mapper = {
|
||||||
|
'int8': 'int',
|
||||||
|
'int16': 'int',
|
||||||
|
'int32': 'int',
|
||||||
|
'int64': 'int',
|
||||||
|
'uint8': 'int',
|
||||||
|
'uint16': 'int',
|
||||||
|
'uint32': 'int',
|
||||||
|
'uint64': 'int',
|
||||||
|
'float16': 'float',
|
||||||
|
'float32': 'float',
|
||||||
|
'float64': 'float',
|
||||||
|
'bool': 'i'
|
||||||
|
}
|
||||||
|
|
||||||
|
def _maybe_pandas_data(self, data, feature_names, feature_types,
|
||||||
|
meta=None, meta_type=None):
|
||||||
|
"""Extract internal data from pd.DataFrame for DMatrix data"""
|
||||||
|
if lazy_isinstance(data, 'pandas.core.series', 'Series'):
|
||||||
|
dtype = meta_type if meta_type else 'float'
|
||||||
|
return data.values.astype(dtype), feature_names, feature_types
|
||||||
|
|
||||||
|
from pandas.api.types import is_sparse
|
||||||
|
from pandas import MultiIndex, Int64Index
|
||||||
|
|
||||||
|
data_dtypes = data.dtypes
|
||||||
|
if not all(dtype.name in self.pandas_dtype_mapper or is_sparse(dtype)
|
||||||
|
for dtype in data_dtypes):
|
||||||
|
bad_fields = [
|
||||||
|
str(data.columns[i]) for i, dtype in enumerate(data_dtypes)
|
||||||
|
if dtype.name not in self.pandas_dtype_mapper
|
||||||
|
]
|
||||||
|
|
||||||
|
msg = """DataFrame.dtypes for data must be int, float or bool.
|
||||||
|
Did not expect the data types in fields """
|
||||||
|
raise ValueError(msg + ', '.join(bad_fields))
|
||||||
|
|
||||||
|
if feature_names is None and meta is None:
|
||||||
|
if isinstance(data.columns, MultiIndex):
|
||||||
|
feature_names = [
|
||||||
|
' '.join([str(x) for x in i]) for i in data.columns
|
||||||
|
]
|
||||||
|
elif isinstance(data.columns, Int64Index):
|
||||||
|
feature_names = list(map(str, data.columns))
|
||||||
|
else:
|
||||||
|
feature_names = data.columns.format()
|
||||||
|
|
||||||
|
if feature_types is None and meta is None:
|
||||||
|
feature_types = []
|
||||||
|
for dtype in data_dtypes:
|
||||||
|
if is_sparse(dtype):
|
||||||
|
feature_types.append(self.pandas_dtype_mapper[
|
||||||
|
dtype.subtype.name])
|
||||||
|
else:
|
||||||
|
feature_types.append(self.pandas_dtype_mapper[dtype.name])
|
||||||
|
|
||||||
|
if meta and len(data.columns) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
'DataFrame for {meta} cannot have multiple columns'.format(
|
||||||
|
meta=meta))
|
||||||
|
|
||||||
|
dtype = meta_type if meta_type else 'float'
|
||||||
|
data = data.values.astype(dtype)
|
||||||
|
|
||||||
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
|
def transform(self, data):
|
||||||
|
return self._maybe_pandas_data(data, None, None, self.meta,
|
||||||
|
self.meta_type)
|
||||||
|
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
data, feature_names, feature_types = self._maybe_pandas_data(
|
||||||
|
data, feature_names, feature_types, self.meta, self.meta_type)
|
||||||
|
return super().handle_input(data, feature_names, feature_types)
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry.register_handler(
|
||||||
|
'pandas.core.frame', 'DataFrame', PandasHandler)
|
||||||
|
__dmatrix_registry.register_handler(
|
||||||
|
'pandas.core.series', 'Series', PandasHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class DTHandler(DataHandler):
|
||||||
|
'''Handler of datatable.'''
|
||||||
|
dt_type_mapper = {'bool': 'bool', 'int': 'int', 'real': 'float'}
|
||||||
|
dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
|
||||||
|
|
||||||
|
def _maybe_dt_data(self, data, feature_names, feature_types,
|
||||||
|
meta=None, meta_type=None):
|
||||||
|
"""Validate feature names and types if data table"""
|
||||||
|
if meta and data.shape[1] > 1:
|
||||||
|
raise ValueError(
|
||||||
|
'DataTable for label or weight cannot have multiple columns')
|
||||||
|
if meta:
|
||||||
|
# below requires new dt version
|
||||||
|
# extract first column
|
||||||
|
data = data.to_numpy()[:, 0].astype(meta_type)
|
||||||
|
return data, None, None
|
||||||
|
|
||||||
|
data_types_names = tuple(lt.name for lt in data.ltypes)
|
||||||
|
bad_fields = [data.names[i]
|
||||||
|
for i, type_name in enumerate(data_types_names)
|
||||||
|
if type_name not in self.dt_type_mapper]
|
||||||
|
if bad_fields:
|
||||||
|
msg = """DataFrame.types for data must be int, float or bool.
|
||||||
|
Did not expect the data types in fields """
|
||||||
|
raise ValueError(msg + ', '.join(bad_fields))
|
||||||
|
|
||||||
|
if feature_names is None and meta is None:
|
||||||
|
feature_names = data.names
|
||||||
|
|
||||||
|
# always return stypes for dt ingestion
|
||||||
|
if feature_types is not None:
|
||||||
|
raise ValueError(
|
||||||
|
'DataTable has own feature types, cannot pass them in.')
|
||||||
|
feature_types = np.vectorize(self.dt_type_mapper2.get)(
|
||||||
|
data_types_names)
|
||||||
|
|
||||||
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
|
def transform(self, data):
|
||||||
|
return self._maybe_dt_data(data, None, None, self.meta, self.meta_type)
|
||||||
|
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
data, feature_names, feature_types = self._maybe_dt_data(
|
||||||
|
data, feature_names, feature_types, self.meta, self.meta_type)
|
||||||
|
|
||||||
|
ptrs = (ctypes.c_void_p * data.ncols)()
|
||||||
|
if hasattr(data, "internal") and hasattr(data.internal, "column"):
|
||||||
|
# datatable>0.8.0
|
||||||
|
for icol in range(data.ncols):
|
||||||
|
col = data.internal.column(icol)
|
||||||
|
ptr = col.data_pointer
|
||||||
|
ptrs[icol] = ctypes.c_void_p(ptr)
|
||||||
|
else:
|
||||||
|
# datatable<=0.8.0
|
||||||
|
from datatable.internal import \
|
||||||
|
frame_column_data_r # pylint: disable=no-name-in-module,import-error
|
||||||
|
for icol in range(data.ncols):
|
||||||
|
ptrs[icol] = frame_column_data_r(data, icol)
|
||||||
|
|
||||||
|
# always return stypes for dt ingestion
|
||||||
|
feature_type_strings = (ctypes.c_char_p * data.ncols)()
|
||||||
|
for icol in range(data.ncols):
|
||||||
|
feature_type_strings[icol] = ctypes.c_char_p(
|
||||||
|
data.stypes[icol].name.encode('utf-8'))
|
||||||
|
|
||||||
|
self._warn_unused_missing(data)
|
||||||
|
handle = ctypes.c_void_p()
|
||||||
|
_check_call(_LIB.XGDMatrixCreateFromDT(
|
||||||
|
ptrs, feature_type_strings,
|
||||||
|
c_bst_ulong(data.shape[0]),
|
||||||
|
c_bst_ulong(data.shape[1]),
|
||||||
|
ctypes.byref(handle),
|
||||||
|
ctypes.c_int(self.nthread)))
|
||||||
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry.register_handler('datatable', 'Frame', DTHandler)
|
||||||
|
__dmatrix_registry.register_handler('datatable', 'DataTable', DTHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class CudaArrayInterfaceHandler(DataHandler):
|
||||||
|
'''Handler of data with `__cuda_array_interface__` (cupy.ndarray).'''
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
"""Initialize DMatrix from cupy ndarray."""
|
||||||
|
interface = data.__cuda_array_interface__
|
||||||
|
if 'mask' in interface:
|
||||||
|
interface['mask'] = interface['mask'].__cuda_array_interface__
|
||||||
|
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
|
||||||
|
|
||||||
|
handle = ctypes.c_void_p()
|
||||||
|
_check_call(
|
||||||
|
_LIB.XGDMatrixCreateFromArrayInterface(
|
||||||
|
interface_str,
|
||||||
|
ctypes.c_float(self.missing),
|
||||||
|
ctypes.c_int(self.nthread),
|
||||||
|
ctypes.byref(handle)))
|
||||||
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry.register_handler('cupy.core.core', 'ndarray',
|
||||||
|
CudaArrayInterfaceHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class CudaColumnarHandler(DataHandler):
|
||||||
|
'''Handler of CUDA based columnar data. (cudf.DataFrame)'''
|
||||||
|
def _maybe_cudf_dataframe(self, data, feature_names, feature_types):
|
||||||
|
"""Extract internal data from cudf.DataFrame for DMatrix data."""
|
||||||
|
if feature_names is None:
|
||||||
|
if lazy_isinstance(data, 'cudf.core.series', 'Series'):
|
||||||
|
feature_names = [data.name]
|
||||||
|
elif lazy_isinstance(
|
||||||
|
data.columns, 'cudf.core.multiindex', 'MultiIndex'):
|
||||||
|
feature_names = [
|
||||||
|
' '.join([str(x) for x in i])
|
||||||
|
for i in data.columns
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
feature_names = data.columns.format()
|
||||||
|
if feature_types is None:
|
||||||
|
if lazy_isinstance(data, 'cudf.core.series', 'Series'):
|
||||||
|
dtypes = [data.dtype]
|
||||||
|
else:
|
||||||
|
dtypes = data.dtypes
|
||||||
|
feature_types = [PandasHandler.pandas_dtype_mapper[d.name]
|
||||||
|
for d in dtypes]
|
||||||
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
|
def transform(self, data):
|
||||||
|
return self._maybe_cudf_dataframe(data, None, None)
|
||||||
|
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
"""Initialize DMatrix from columnar memory format."""
|
||||||
|
data, feature_names, feature_types = self._maybe_cudf_dataframe(
|
||||||
|
data, feature_names, feature_types)
|
||||||
|
interfaces_str = _cudf_array_interfaces(data)
|
||||||
|
handle = ctypes.c_void_p()
|
||||||
|
_check_call(
|
||||||
|
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
|
||||||
|
interfaces_str,
|
||||||
|
ctypes.c_float(self.missing),
|
||||||
|
ctypes.c_int(self.nthread),
|
||||||
|
ctypes.byref(handle)))
|
||||||
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry.register_handler('cudf.core.dataframe', 'DataFrame',
|
||||||
|
CudaColumnarHandler)
|
||||||
|
__dmatrix_registry.register_handler('cudf.core.series', 'Series',
|
||||||
|
CudaColumnarHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class DLPackHandler(CudaArrayInterfaceHandler):
|
||||||
|
'''Handler of `dlpack`.'''
|
||||||
|
def _maybe_dlpack_data(self, data, feature_names, feature_types):
|
||||||
|
from cupy import fromDlpack # pylint: disable=E0401
|
||||||
|
data = fromDlpack(data)
|
||||||
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
|
def transform(self, data):
|
||||||
|
return self._maybe_dlpack_data(data, None, None)
|
||||||
|
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
data, feature_names, feature_types = self._maybe_dlpack_data(
|
||||||
|
data, feature_names, feature_types)
|
||||||
|
return super().handle_input(
|
||||||
|
data, feature_names, feature_types)
|
||||||
|
|
||||||
|
|
||||||
|
__dmatrix_registry.register_handler_opaque(
|
||||||
|
lambda x: 'PyCapsule' in str(type(x)) and "dltensor" in str(x),
|
||||||
|
DLPackHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceQuantileDMatrixDataHandler(DataHandler): # pylint: disable=abstract-method
|
||||||
|
'''Base class of data handler for `DeviceQuantileDMatrix`.'''
|
||||||
|
def __init__(self, max_bin, missing, nthread, silent,
|
||||||
|
meta=None, meta_type=None):
|
||||||
|
self.max_bin = max_bin
|
||||||
|
super().__init__(missing, nthread, silent, meta, meta_type)
|
||||||
|
|
||||||
|
|
||||||
|
__device_quantile_dmatrix_registry = DMatrixDataManager() # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_quantile_dmatrix_data_handler(
|
||||||
|
data, max_bin, missing, nthread, silent):
|
||||||
|
'''Get data handler for `DeviceQuantileDMatrix`. Similar to
|
||||||
|
`get_dmatrix_data_handler`.
|
||||||
|
|
||||||
|
.. versionadded:: 1.2.0
|
||||||
|
|
||||||
|
'''
|
||||||
|
handler = __device_quantile_dmatrix_registry.get_handler(
|
||||||
|
data)
|
||||||
|
assert handler, 'Current data type ' + str(type(data)) +\
|
||||||
|
' is not supported for DeviceQuantileDMatrix'
|
||||||
|
return handler(max_bin, missing, nthread, silent)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceQuantileCudaArrayInterfaceHandler(
|
||||||
|
DeviceQuantileDMatrixDataHandler):
|
||||||
|
'''Handler of data with `__cuda_array_interface__`, for
|
||||||
|
`DeviceQuantileDMatrix`.
|
||||||
|
|
||||||
|
'''
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
"""Initialize DMatrix from cupy ndarray."""
|
||||||
|
if not hasattr(data, '__cuda_array_interface__') and hasattr(
|
||||||
|
data, '__array__'):
|
||||||
|
import cupy # pylint: disable=import-error
|
||||||
|
data = cupy.array(data, copy=False)
|
||||||
|
|
||||||
|
interface = data.__cuda_array_interface__
|
||||||
|
if 'mask' in interface:
|
||||||
|
interface['mask'] = interface['mask'].__cuda_array_interface__
|
||||||
|
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
|
||||||
|
|
||||||
|
handle = ctypes.c_void_p()
|
||||||
|
_check_call(
|
||||||
|
_LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
|
||||||
|
interface_str,
|
||||||
|
ctypes.c_float(self.missing), ctypes.c_int(self.nthread),
|
||||||
|
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
|
||||||
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
__device_quantile_dmatrix_registry.register_handler(
|
||||||
|
'cupy.core.core', 'ndarray', DeviceQuantileCudaArrayInterfaceHandler)
|
||||||
|
__device_quantile_dmatrix_registry.register_handler_opaque(
|
||||||
|
lambda x: hasattr(x, '__array__'), NumpyHandler)
|
||||||
|
__device_quantile_dmatrix_registry.register_handler_opaque(
|
||||||
|
lambda x: hasattr(x, '__cuda_array_interface__'), NumpyHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceQuantileCudaColumnarHandler(DeviceQuantileDMatrixDataHandler,
|
||||||
|
CudaColumnarHandler):
|
||||||
|
'''Handler of CUDA based columnar data, for `DeviceQuantileDMatrix`.'''
|
||||||
|
def __init__(self, max_bin, missing, nthread, silent,
|
||||||
|
meta=None, meta_type=None):
|
||||||
|
super().__init__(
|
||||||
|
max_bin=max_bin, missing=missing, nthread=nthread, silent=silent,
|
||||||
|
meta=meta, meta_type=meta_type
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
"""Initialize Quantile Device DMatrix from columnar memory format."""
|
||||||
|
data, feature_names, feature_types = self._maybe_cudf_dataframe(
|
||||||
|
data, feature_names, feature_types)
|
||||||
|
interfaces_str = _cudf_array_interfaces(data)
|
||||||
|
handle = ctypes.c_void_p()
|
||||||
|
_check_call(
|
||||||
|
_LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
|
||||||
|
interfaces_str,
|
||||||
|
ctypes.c_float(self.missing), ctypes.c_int(self.nthread),
|
||||||
|
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
|
||||||
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
__device_quantile_dmatrix_registry.register_handler(
|
||||||
|
'cudf.core.dataframe', 'DataFrame', DeviceQuantileCudaColumnarHandler)
|
||||||
|
__device_quantile_dmatrix_registry.register_handler(
|
||||||
|
'cudf.core.series', 'Series', DeviceQuantileCudaColumnarHandler)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceQuantileDLPackHandler(DeviceQuantileCudaArrayInterfaceHandler,
|
||||||
|
DLPackHandler):
|
||||||
|
'''Handler of `dlpack`, for `DeviceQuantileDMatrix`.'''
|
||||||
|
def __init__(self, max_bin, missing, nthread, silent,
|
||||||
|
meta=None, meta_type=None):
|
||||||
|
super().__init__(
|
||||||
|
max_bin=max_bin, missing=missing, nthread=nthread, silent=silent,
|
||||||
|
meta=meta, meta_type=meta_type
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle_input(self, data, feature_names, feature_types):
|
||||||
|
data, feature_names, feature_types = self._maybe_dlpack_data(
|
||||||
|
data, feature_names, feature_types)
|
||||||
|
return super().handle_input(
|
||||||
|
data, feature_names, feature_types)
|
||||||
|
|
||||||
|
|
||||||
|
__device_quantile_dmatrix_registry.register_handler_opaque(
|
||||||
|
lambda x: 'PyCapsule' in str(type(x)) and "dltensor" in str(x),
|
||||||
|
DeviceQuantileDLPackHandler)
|
||||||
@ -246,7 +246,7 @@ class XGBModel(XGBModelBase):
|
|||||||
|
|
||||||
def _more_tags(self):
|
def _more_tags(self):
|
||||||
'''Tags used for scikit-learn data validation.'''
|
'''Tags used for scikit-learn data validation.'''
|
||||||
return {'allow_nan': True}
|
return {'allow_nan': True, 'no_validation': True}
|
||||||
|
|
||||||
def get_booster(self):
|
def get_booster(self):
|
||||||
"""Get the underlying xgboost Booster of this model.
|
"""Get the underlying xgboost Booster of this model.
|
||||||
@ -258,7 +258,8 @@ class XGBModel(XGBModelBase):
|
|||||||
booster : a xgboost booster of underlying model
|
booster : a xgboost booster of underlying model
|
||||||
"""
|
"""
|
||||||
if not hasattr(self, '_Booster'):
|
if not hasattr(self, '_Booster'):
|
||||||
raise XGBoostError('need to call fit or load_model beforehand')
|
from sklearn.exceptions import NotFittedError
|
||||||
|
raise NotFittedError('need to call fit or load_model beforehand')
|
||||||
return self._Booster
|
return self._Booster
|
||||||
|
|
||||||
def set_params(self, **params):
|
def set_params(self, **params):
|
||||||
@ -332,7 +333,7 @@ class XGBModel(XGBModelBase):
|
|||||||
for k, v in internal.items():
|
for k, v in internal.items():
|
||||||
if k in params.keys() and params[k] is None:
|
if k in params.keys() and params[k] is None:
|
||||||
params[k] = parse_parameter(v)
|
params[k] = parse_parameter(v)
|
||||||
except XGBoostError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
return params
|
return params
|
||||||
|
|
||||||
@ -536,12 +537,16 @@ class XGBModel(XGBModelBase):
|
|||||||
else:
|
else:
|
||||||
params.update({'eval_metric': eval_metric})
|
params.update({'eval_metric': eval_metric})
|
||||||
|
|
||||||
self._Booster = train(params, train_dmatrix,
|
try:
|
||||||
self.get_num_boosting_rounds(), evals=evals,
|
self._Booster = train(params, train_dmatrix,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
self.get_num_boosting_rounds(), evals=evals,
|
||||||
evals_result=evals_result, obj=obj, feval=feval,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
verbose_eval=verbose, xgb_model=xgb_model,
|
evals_result=evals_result,
|
||||||
callbacks=callbacks)
|
obj=obj, feval=feval,
|
||||||
|
verbose_eval=verbose, xgb_model=xgb_model,
|
||||||
|
callbacks=callbacks)
|
||||||
|
except XGBoostError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
if evals_result:
|
if evals_result:
|
||||||
for val in evals_result.items():
|
for val in evals_result.items():
|
||||||
@ -1225,13 +1230,16 @@ class XGBRanker(XGBModel):
|
|||||||
'Custom evaluation metric is not yet supported for XGBRanker.')
|
'Custom evaluation metric is not yet supported for XGBRanker.')
|
||||||
params.update({'eval_metric': eval_metric})
|
params.update({'eval_metric': eval_metric})
|
||||||
|
|
||||||
self._Booster = train(params, train_dmatrix,
|
try:
|
||||||
self.n_estimators,
|
self._Booster = train(params, train_dmatrix,
|
||||||
early_stopping_rounds=early_stopping_rounds,
|
self.n_estimators,
|
||||||
evals=evals,
|
early_stopping_rounds=early_stopping_rounds,
|
||||||
evals_result=evals_result, feval=feval,
|
evals=evals,
|
||||||
verbose_eval=verbose, xgb_model=xgb_model,
|
evals_result=evals_result, feval=feval,
|
||||||
callbacks=callbacks)
|
verbose_eval=verbose, xgb_model=xgb_model,
|
||||||
|
callbacks=callbacks)
|
||||||
|
except XGBoostError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
self.objective = params["objective"]
|
self.objective = params["objective"]
|
||||||
|
|
||||||
|
|||||||
@ -22,6 +22,16 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
|||||||
dtrain = DMatrixT(X, missing=missing, label=y)
|
dtrain = DMatrixT(X, missing=missing, label=y)
|
||||||
assert dtrain.num_col() == kCols
|
assert dtrain.num_col() == kCols
|
||||||
assert dtrain.num_row() == kRows
|
assert dtrain.num_row() == kRows
|
||||||
|
|
||||||
|
if DMatrixT is xgb.DeviceQuantileDMatrix:
|
||||||
|
# Slice is not supported by DeviceQuantileDMatrix
|
||||||
|
with pytest.raises(xgb.core.XGBoostError):
|
||||||
|
dtrain.slice(rindex=[0, 1, 2])
|
||||||
|
dtrain.slice(rindex=[0, 1, 2])
|
||||||
|
else:
|
||||||
|
dtrain.slice(rindex=[0, 1, 2])
|
||||||
|
dtrain.slice(rindex=[0, 1, 2])
|
||||||
|
|
||||||
return dtrain
|
return dtrain
|
||||||
|
|
||||||
|
|
||||||
@ -41,7 +51,7 @@ def _test_from_cupy(DMatrixT):
|
|||||||
|
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
X = cp.random.randn(2, 2, dtype="float32")
|
X = cp.random.randn(2, 2, dtype="float32")
|
||||||
dtrain = DMatrixT(X, label=X)
|
DMatrixT(X, label=X)
|
||||||
|
|
||||||
|
|
||||||
def _test_cupy_training(DMatrixT):
|
def _test_cupy_training(DMatrixT):
|
||||||
@ -88,11 +98,14 @@ def _test_cupy_metainfo(DMatrixT):
|
|||||||
dmat_cupy.set_interface_info('group', cupy_uints)
|
dmat_cupy.set_interface_info('group', cupy_uints)
|
||||||
|
|
||||||
# Test setting info with cupy
|
# Test setting info with cupy
|
||||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cupy.get_float_info('weight'))
|
assert np.array_equal(dmat.get_float_info('weight'),
|
||||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cupy.get_float_info('label'))
|
dmat_cupy.get_float_info('weight'))
|
||||||
|
assert np.array_equal(dmat.get_float_info('label'),
|
||||||
|
dmat_cupy.get_float_info('label'))
|
||||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||||
dmat_cupy.get_float_info('base_margin'))
|
dmat_cupy.get_float_info('base_margin'))
|
||||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
|
assert np.array_equal(dmat.get_uint_info('group_ptr'),
|
||||||
|
dmat_cupy.get_uint_info('group_ptr'))
|
||||||
|
|
||||||
|
|
||||||
class TestFromCupy:
|
class TestFromCupy:
|
||||||
@ -135,7 +148,9 @@ Arrow specification.'''
|
|||||||
import cupy as cp
|
import cupy as cp
|
||||||
n = 100
|
n = 100
|
||||||
X = cp.random.random((n, 2))
|
X = cp.random.random((n, 2))
|
||||||
xgb.DeviceQuantileDMatrix(X.toDlpack())
|
m = xgb.DeviceQuantileDMatrix(X.toDlpack())
|
||||||
|
with pytest.raises(xgb.core.XGBoostError):
|
||||||
|
m.slice(rindex=[0, 1, 2])
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
@pytest.mark.mgpu
|
@pytest.mark.mgpu
|
||||||
|
|||||||
@ -67,7 +67,8 @@ class TestPandas(unittest.TestCase):
|
|||||||
# 0 1 1 0 0
|
# 0 1 1 0 0
|
||||||
# 1 2 0 1 0
|
# 1 2 0 1 0
|
||||||
# 2 3 0 0 1
|
# 2 3 0 0 1
|
||||||
result, _, _ = xgb.core._maybe_pandas_data(dummies, None, None)
|
pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
|
||||||
|
result, _, _ = pandas_handler._maybe_pandas_data(dummies, None, None)
|
||||||
exp = np.array([[1., 1., 0., 0.],
|
exp = np.array([[1., 1., 0., 0.],
|
||||||
[2., 0., 1., 0.],
|
[2., 0., 1., 0.],
|
||||||
[3., 0., 0., 1.]])
|
[3., 0., 0., 1.]])
|
||||||
@ -113,12 +114,12 @@ class TestPandas(unittest.TestCase):
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
rows = 100
|
rows = 100
|
||||||
X = pd.DataFrame(
|
X = pd.DataFrame(
|
||||||
{"A": pd.SparseArray(np.random.randint(0, 10, size=rows)),
|
{"A": pd.arrays.SparseArray(np.random.randint(0, 10, size=rows)),
|
||||||
"B": pd.SparseArray(np.random.randn(rows)),
|
"B": pd.arrays.SparseArray(np.random.randn(rows)),
|
||||||
"C": pd.SparseArray(np.random.permutation(
|
"C": pd.arrays.SparseArray(np.random.permutation(
|
||||||
[True, False] * (rows // 2)))}
|
[True, False] * (rows // 2)))}
|
||||||
)
|
)
|
||||||
y = pd.Series(pd.SparseArray(np.random.randn(rows)))
|
y = pd.Series(pd.arrays.SparseArray(np.random.randn(rows)))
|
||||||
dtrain = xgb.DMatrix(X, y)
|
dtrain = xgb.DMatrix(X, y)
|
||||||
booster = xgb.train({}, dtrain, num_boost_round=4)
|
booster = xgb.train({}, dtrain, num_boost_round=4)
|
||||||
predt_sparse = booster.predict(xgb.DMatrix(X))
|
predt_sparse = booster.predict(xgb.DMatrix(X))
|
||||||
@ -128,17 +129,18 @@ class TestPandas(unittest.TestCase):
|
|||||||
def test_pandas_label(self):
|
def test_pandas_label(self):
|
||||||
# label must be a single column
|
# label must be a single column
|
||||||
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||||
self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
|
pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
|
||||||
|
self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
|
||||||
None, None, 'label', 'float')
|
None, None, 'label', 'float')
|
||||||
|
|
||||||
# label must be supported dtype
|
# label must be supported dtype
|
||||||
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
|
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
|
||||||
self.assertRaises(ValueError, xgb.core._maybe_pandas_data, df,
|
self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
|
||||||
None, None, 'label', 'float')
|
None, None, 'label', 'float')
|
||||||
|
|
||||||
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
|
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
|
||||||
result, _, _ = xgb.core._maybe_pandas_data(df, None, None,
|
result, _, _ = pandas_handler._maybe_pandas_data(df, None, None,
|
||||||
'label', 'float')
|
'label', 'float')
|
||||||
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
|
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
|
||||||
dtype=float))
|
dtype=float))
|
||||||
dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
|
dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user