Simplify the data backends. (#5893)

This commit is contained in:
Jiaming Yuan 2020-07-16 15:17:31 +08:00 committed by GitHub
parent 7aee0e51ed
commit 029a8b533f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 790 additions and 806 deletions

View File

@ -4,13 +4,14 @@
"""Core XGBoost Library.""" """Core XGBoost Library."""
import collections import collections
# pylint: disable=no-name-in-module,import-error # pylint: disable=no-name-in-module,import-error
from collections.abc import Mapping # Python 3 from collections.abc import Mapping
# pylint: enable=no-name-in-module,import-error # pylint: enable=no-name-in-module,import-error
import ctypes import ctypes
import os import os
import re import re
import sys import sys
import json import json
import warnings
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
@ -267,7 +268,6 @@ def _convert_unknown_data(data, meta=None, meta_type=None):
raise TypeError('Can not handle data from {}'.format( raise TypeError('Can not handle data from {}'.format(
type(data).__name__)) from e type(data).__name__)) from e
else: else:
import warnings
warnings.warn( warnings.warn(
'Unknown data type: ' + str(type(data)) + 'Unknown data type: ' + str(type(data)) +
', coverting it to csr_matrix') ', coverting it to csr_matrix')
@ -279,27 +279,6 @@ def _convert_unknown_data(data, meta=None, meta_type=None):
return data return data
# Either object has cuda array interface or contains columns with interfaces
def _has_cuda_array_interface(data):
return hasattr(data, '__cuda_array_interface__') or \
lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame')
def _cudf_array_interfaces(df):
'''Extract CuDF __cuda_array_interface__'''
interfaces = []
if lazy_isinstance(df, 'cudf.core.series', 'Series'):
interfaces.append(df.__cuda_array_interface__)
else:
for col in df:
interface = df[col].__cuda_array_interface__
if 'mask' in interface:
interface['mask'] = interface['mask'].__cuda_array_interface__
interfaces.append(interface)
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
return interfaces_str
class DataIter: class DataIter:
'''The interface for user defined data iterator. Currently is only '''The interface for user defined data iterator. Currently is only
supported by Device DMatrix. supported by Device DMatrix.
@ -331,7 +310,7 @@ class DataIter:
'''A wrapper for user defined `next` function. '''A wrapper for user defined `next` function.
`this` is not used in Python. ctypes can handle `self` of a Python `this` is not used in Python. ctypes can handle `self` of a Python
member function automatically when converting a it to c function member function automatically when converting it to c function
pointer. pointer.
''' '''
@ -340,32 +319,30 @@ class DataIter:
def data_handle(data, label=None, weight=None, base_margin=None, def data_handle(data, label=None, weight=None, base_margin=None,
group=None, group=None,
label_lower_bound=None, label_upper_bound=None): label_lower_bound=None, label_upper_bound=None,
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'): feature_names=None, feature_types=None):
# pylint: disable=protected-access from .data import dispatch_device_quantile_dmatrix_set_data
self.proxy._set_data_from_cuda_columnar(data) from .data import _device_quantile_transform
elif lazy_isinstance(data, 'cudf.core.series', 'Series'): data, feature_names, feature_types = _device_quantile_transform(
# pylint: disable=protected-access data, feature_names, feature_types
self.proxy._set_data_from_cuda_columnar(data) )
elif lazy_isinstance(data, 'cupy.core.core', 'ndarray'): dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
# pylint: disable=protected-access
self.proxy._set_data_from_cuda_interface(data)
else:
raise TypeError(
'Value type is not supported for data iterator:' +
str(type(self._handle)), type(data))
self.proxy.set_info(label=label, weight=weight, self.proxy.set_info(label=label, weight=weight,
base_margin=base_margin, base_margin=base_margin,
group=group, group=group,
label_lower_bound=label_lower_bound, label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound) label_upper_bound=label_upper_bound,
feature_names=feature_names,
feature_types=feature_types)
try: try:
# Deffer the exception in order to return 0 and stop the iteration. # Differ the exception in order to return 0 and stop the iteration.
# Exception inside a ctype callback function has no effect except # Exception inside a ctype callback function has no effect except
# for printing to stderr (doesn't stop the execution). # for printing to stderr (doesn't stop the execution).
ret = self.next(data_handle) # pylint: disable=not-callable ret = self.next(data_handle) # pylint: disable=not-callable
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
tb = sys.exc_info()[2] tb = sys.exc_info()[2]
# On dask the worker is restarted and somehow the information is
# lost.
self.exception = e.with_traceback(tb) self.exception = e.with_traceback(tb)
return 0 return 0
return ret return ret
@ -453,41 +430,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.handle = None self.handle = None
return return
handler = self._get_data_handler(data) from .data import dispatch_data_backend
can_handle_meta = False handle, feature_names, feature_types = dispatch_data_backend(
if handler is None: data, missing=self.missing,
data = _convert_unknown_data(data, None) threads=self.nthread,
handler = self._get_data_handler(data) feature_names=feature_names,
try: feature_types=feature_types)
handler.handle_meta(label, weight, base_margin) assert handle is not None
can_handle_meta = True self.handle = handle
except NotImplementedError:
can_handle_meta = False
self.handle, feature_names, feature_types = handler.handle_input( self.set_info(label=label, weight=weight, base_margin=base_margin)
data, feature_names, feature_types)
assert self.handle, 'Failed to construct a DMatrix.'
if not can_handle_meta:
self.set_info(label, weight, base_margin)
self.feature_names = feature_names self.feature_names = feature_names
self.feature_types = feature_types self.feature_types = feature_types
def _get_data_handler(self, data, meta=None, meta_type=None):
'''Get data handler for this DMatrix class.'''
from .data import get_dmatrix_data_handler
handler = get_dmatrix_data_handler(
data, self.missing, self.nthread, self.silent, meta, meta_type)
return handler
# pylint: disable=no-self-use
def _get_meta_handler(self, data, meta, meta_type):
from .data import get_dmatrix_meta_handler
handler = get_dmatrix_meta_handler(
data, meta, meta_type)
return handler
def __del__(self): def __del__(self):
if hasattr(self, "handle") and self.handle: if hasattr(self, "handle") and self.handle:
_check_call(_LIB.XGDMatrixFree(self.handle)) _check_call(_LIB.XGDMatrixFree(self.handle))
@ -497,7 +453,9 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
label=None, weight=None, base_margin=None, label=None, weight=None, base_margin=None,
group=None, group=None,
label_lower_bound=None, label_lower_bound=None,
label_upper_bound=None): label_upper_bound=None,
feature_names=None,
feature_types=None):
'''Set meta info for DMatrix.''' '''Set meta info for DMatrix.'''
if label is not None: if label is not None:
self.set_label(label) self.set_label(label)
@ -511,6 +469,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.set_float_info('label_lower_bound', label_lower_bound) self.set_float_info('label_lower_bound', label_lower_bound)
if label_upper_bound is not None: if label_upper_bound is not None:
self.set_float_info('label_upper_bound', label_upper_bound) self.set_float_info('label_upper_bound', label_upper_bound)
if feature_names is not None:
self.feature_names = feature_names
if feature_types is not None:
self.feature_types = feature_types
def get_float_info(self, field): def get_float_info(self, field):
"""Get float property from the DMatrix. """Get float property from the DMatrix.
@ -565,17 +527,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
data: numpy array data: numpy array
The array of data to be set The array of data to be set
""" """
if isinstance(data, np.ndarray): from .data import dispatch_meta_backend
self.set_float_info_npy2d(field, data) dispatch_meta_backend(self, data, field, 'float')
return
handler = self._get_data_handler(data, field, np.float32)
assert handler
data, _, _ = handler.transform(data)
c_data = c_array(ctypes.c_float, data)
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
c_data,
c_bst_ulong(len(data))))
def set_float_info_npy2d(self, field, data): def set_float_info_npy2d(self, field, data):
"""Set float type property into the DMatrix """Set float type property into the DMatrix
@ -589,13 +542,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
data: numpy array data: numpy array
The array of data to be set The array of data to be set
""" """
data, _, _ = self._get_meta_handler( from .data import dispatch_meta_backend
data, field, np.float32).transform(data) dispatch_meta_backend(self, data, field, 'float')
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
c_data,
c_bst_ulong(len(data))))
def set_uint_info(self, field, data): def set_uint_info(self, field, data):
"""Set uint type property into the DMatrix. """Set uint type property into the DMatrix.
@ -608,27 +556,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
data: numpy array data: numpy array
The array of data to be set The array of data to be set
""" """
data, _, _ = self._get_data_handler( from .data import dispatch_meta_backend
data, field, 'uint32').transform(data) dispatch_meta_backend(self, data, field, 'uint32')
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
c_str(field),
c_array(ctypes.c_uint, data),
c_bst_ulong(len(data))))
def set_interface_info(self, field, data):
"""Set info type property into DMatrix."""
# If we are passed a dataframe, extract the series
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
if len(data.columns) != 1:
raise ValueError(
'Expecting meta-info to contain a single column')
data = data[data.columns[0]]
interface = bytes(json.dumps([data.__cuda_array_interface__],
indent=2), 'utf-8')
_check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle,
c_str(field),
interface))
def save_binary(self, fname, silent=True): def save_binary(self, fname, silent=True):
"""Save DMatrix to an XGBoost buffer. Saved binary can be later loaded """Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
@ -653,10 +582,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
label: array like label: array like
The label information to be set into DMatrix The label information to be set into DMatrix
""" """
if _has_cuda_array_interface(label): from .data import dispatch_meta_backend
self.set_interface_info('label', label) dispatch_meta_backend(self, label, 'label', 'float')
else:
self.set_float_info('label', label)
def set_weight(self, weight): def set_weight(self, weight):
"""Set weight of each instance. """Set weight of each instance.
@ -674,10 +601,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
sense to assign weights to individual data points. sense to assign weights to individual data points.
""" """
if _has_cuda_array_interface(weight): from .data import dispatch_meta_backend
self.set_interface_info('weight', weight) dispatch_meta_backend(self, weight, 'weight', 'float')
else:
self.set_float_info('weight', weight)
def set_base_margin(self, margin): def set_base_margin(self, margin):
"""Set base margin of booster to start from. """Set base margin of booster to start from.
@ -693,10 +618,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
Prediction margin of each datapoint Prediction margin of each datapoint
""" """
if _has_cuda_array_interface(margin): from .data import dispatch_meta_backend
self.set_interface_info('base_margin', margin) dispatch_meta_backend(self, margin, 'base_margin', 'float')
else:
self.set_float_info('base_margin', margin)
def set_group(self, group): def set_group(self, group):
"""Set group size of DMatrix (used for ranking). """Set group size of DMatrix (used for ranking).
@ -706,10 +629,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
group : array like group : array like
Group size of each group Group size of each group
""" """
if _has_cuda_array_interface(group): from .data import dispatch_meta_backend
self.set_interface_info('group', group) dispatch_meta_backend(self, group, 'group', 'uint32')
else:
self.set_uint_info('group', group)
def get_label(self): def get_label(self):
"""Get the label of the DMatrix. """Get the label of the DMatrix.
@ -830,7 +751,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
if len(feature_names) != len(set(feature_names)): if len(feature_names) != len(set(feature_names)):
raise ValueError('feature_names must be unique') raise ValueError('feature_names must be unique')
if len(feature_names) != self.num_col(): if len(feature_names) != self.num_col() and self.num_col() != 0:
msg = 'feature_names must have the same length as data' msg = 'feature_names must have the same length as data'
raise ValueError(msg) raise ValueError(msg)
# prohibit to use symbols may affect to parse. e.g. []< # prohibit to use symbols may affect to parse. e.g. []<
@ -935,28 +856,35 @@ class DeviceQuantileDMatrix(DMatrix):
""" """
def __init__(self, data, label=None, weight=None, base_margin=None, def __init__(self, data, label=None, weight=None, # pylint: disable=W0231
base_margin=None,
missing=None, missing=None,
silent=False, silent=False,
feature_names=None, feature_names=None,
feature_types=None, feature_types=None,
nthread=None, max_bin=256): nthread=None, max_bin=256):
self.max_bin = max_bin self.max_bin = max_bin
self.missing = missing if missing is not None else np.nan
self.nthread = nthread if nthread is not None else 1
if isinstance(data, ctypes.c_void_p): if isinstance(data, ctypes.c_void_p):
self.handle = data self.handle = data
return return
super().__init__(data, label=label, weight=weight, from .data import init_device_quantile_dmatrix
handle, feature_names, feature_types = init_device_quantile_dmatrix(
data, missing=self.missing, threads=self.nthread,
max_bin=self.max_bin,
label=label, weight=weight,
base_margin=base_margin, base_margin=base_margin,
missing=missing, group=None,
silent=silent, label_lower_bound=None,
label_upper_bound=None,
feature_names=feature_names, feature_names=feature_names,
feature_types=feature_types, feature_types=feature_types)
nthread=nthread) self.handle = handle
def _get_data_handler(self, data, meta=None, meta_type=None): self.feature_names = feature_names
from .data import get_device_quantile_dmatrix_data_handler self.feature_types = feature_types
return get_device_quantile_dmatrix_data_handler(
data, self.max_bin, self.missing, self.nthread, self.silent)
def _set_data_from_cuda_interface(self, data): def _set_data_from_cuda_interface(self, data):
'''Set data from CUDA array interface.''' '''Set data from CUDA array interface.'''
@ -971,6 +899,7 @@ class DeviceQuantileDMatrix(DMatrix):
def _set_data_from_cuda_columnar(self, data): def _set_data_from_cuda_columnar(self, data):
'''Set data from CUDA columnar format.1''' '''Set data from CUDA columnar format.1'''
from .data import _cudf_array_interfaces
interfaces_str = _cudf_array_interfaces(data) interfaces_str = _cudf_array_interfaces(data)
_check_call( _check_call(
_LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar( _LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
@ -1592,6 +1521,7 @@ class Booster(object):
rows = data.shape[0] rows = data.shape[0]
return reshape_output(mem, rows) return reshape_output(mem, rows)
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'): if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
from .data import _cudf_array_interfaces
interfaces_str = _cudf_array_interfaces(data) interfaces_str = _cudf_array_interfaces(data)
_check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns( _check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns(
self.handle, self.handle,

File diff suppressed because it is too large Load Diff

View File

@ -12,12 +12,16 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
auto const& value = adapter->Value(); auto const& value = adapter->Value();
this->batch_ = adapter; this->batch_ = adapter;
device_ = adapter->DeviceIdx(); device_ = adapter->DeviceIdx();
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
} }
void DMatrixProxy::FromCudaArray(std::string interface_str) { void DMatrixProxy::FromCudaArray(std::string interface_str) {
std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str)); std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
this->batch_ = adapter; this->batch_ = adapter;
device_ = adapter->DeviceIdx(); device_ = adapter->DeviceIdx();
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
} }
} // namespace data } // namespace data

View File

@ -12,11 +12,12 @@ import testing as tm
class TestDeviceQuantileDMatrix(unittest.TestCase): class TestDeviceQuantileDMatrix(unittest.TestCase):
def test_dmatrix_numpy_init(self): def test_dmatrix_numpy_init(self):
data = np.random.randn(5, 5) data = np.random.randn(5, 5)
with pytest.raises(AssertionError, match='is not supported for DeviceQuantileDMatrix'): with pytest.raises(TypeError,
dm = xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64)) match='is not supported for DeviceQuantileDMatrix'):
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())
def test_dmatrix_cupy_init(self): def test_dmatrix_cupy_init(self):
import cupy as cp import cupy as cp
data = cp.random.randn(5, 5) data = cp.random.randn(5, 5)
dm = xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64)) xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))

View File

@ -119,10 +119,10 @@ def _test_cudf_metainfo(DMatrixT):
dmat.set_float_info('label', floats) dmat.set_float_info('label', floats)
dmat.set_float_info('base_margin', floats) dmat.set_float_info('base_margin', floats)
dmat.set_uint_info('group', uints) dmat.set_uint_info('group', uints)
dmat_cudf.set_interface_info('weight', cudf_floats) dmat_cudf.set_info(weight=cudf_floats)
dmat_cudf.set_interface_info('label', cudf_floats) dmat_cudf.set_info(label=cudf_floats)
dmat_cudf.set_interface_info('base_margin', cudf_floats) dmat_cudf.set_info(base_margin=cudf_floats)
dmat_cudf.set_interface_info('group', cudf_uints) dmat_cudf.set_info(group=cudf_uints)
# Test setting info with cudf DataFrame # Test setting info with cudf DataFrame
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight')) assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
@ -132,10 +132,10 @@ def _test_cudf_metainfo(DMatrixT):
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr')) assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
# Test setting info with cudf Series # Test setting info with cudf Series
dmat_cudf.set_interface_info('weight', cudf_floats[cudf_floats.columns[0]]) dmat_cudf.set_info(weight=cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_interface_info('label', cudf_floats[cudf_floats.columns[0]]) dmat_cudf.set_info(label=cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_interface_info('base_margin', cudf_floats[cudf_floats.columns[0]]) dmat_cudf.set_info(base_margin=cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_interface_info('group', cudf_uints[cudf_uints.columns[0]]) dmat_cudf.set_info(group=cudf_uints[cudf_uints.columns[0]])
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight')) assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label')) assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
assert np.array_equal(dmat.get_float_info('base_margin'), assert np.array_equal(dmat.get_float_info('base_margin'),

View File

@ -92,10 +92,10 @@ def _test_cupy_metainfo(DMatrixT):
dmat.set_float_info('label', floats) dmat.set_float_info('label', floats)
dmat.set_float_info('base_margin', floats) dmat.set_float_info('base_margin', floats)
dmat.set_uint_info('group', uints) dmat.set_uint_info('group', uints)
dmat_cupy.set_interface_info('weight', cupy_floats) dmat_cupy.set_info(weight=cupy_floats)
dmat_cupy.set_interface_info('label', cupy_floats) dmat_cupy.set_info(label=cupy_floats)
dmat_cupy.set_interface_info('base_margin', cupy_floats) dmat_cupy.set_info(base_margin=cupy_floats)
dmat_cupy.set_interface_info('group', cupy_uints) dmat_cupy.set_info(group=cupy_uints)
# Test setting info with cupy # Test setting info with cupy
assert np.array_equal(dmat.get_float_info('weight'), assert np.array_equal(dmat.get_float_info('weight'),

View File

@ -1,17 +1,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import sys import sys
from contextlib import contextmanager from contextlib import contextmanager
try:
# python 2
from StringIO import StringIO
except ImportError:
# python 3
from io import StringIO from io import StringIO
import numpy as np import numpy as np
import os
import xgboost as xgb import xgboost as xgb
import unittest import unittest
import json import json
from pathlib import Path from pathlib import Path
import tempfile
dpath = 'demo/data/' dpath = 'demo/data/'
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
@ -66,13 +63,16 @@ class TestBasic(unittest.TestCase):
# error must be smaller than 10% # error must be smaller than 10%
assert err < 0.1 assert err < 0.1
with tempfile.TemporaryDirectory() as tmpdir:
dtest_path = os.path.join(tmpdir, 'dtest.dmatrix')
# save dmatrix into binary buffer # save dmatrix into binary buffer
dtest.save_binary('dtest.buffer') dtest.save_binary(dtest_path)
# save model # save model
bst.save_model('xgb.model') model_path = os.path.join(tmpdir, 'model.booster')
bst.save_model(model_path)
# load model and data in # load model and data in
bst2 = xgb.Booster(model_file='xgb.model') bst2 = xgb.Booster(model_file=model_path)
dtest2 = xgb.DMatrix('dtest.buffer') dtest2 = xgb.DMatrix(dtest_path)
preds2 = bst2.predict(dtest2) preds2 = bst2.predict(dtest2)
# assert they are the same # assert they are the same
assert np.sum(np.abs(preds2 - preds)) == 0 assert np.sum(np.abs(preds2 - preds)) == 0

View File

@ -67,8 +67,7 @@ class TestPandas(unittest.TestCase):
# 0 1 1 0 0 # 0 1 1 0 0
# 1 2 0 1 0 # 1 2 0 1 0
# 2 3 0 0 1 # 2 3 0 0 1
pandas_handler = xgb.data.PandasHandler(np.nan, 0, False) result, _, _ = xgb.data._transform_pandas_df(dummies)
result, _, _ = pandas_handler._maybe_pandas_data(dummies, None, None)
exp = np.array([[1., 1., 0., 0.], exp = np.array([[1., 1., 0., 0.],
[2., 0., 1., 0.], [2., 0., 1., 0.],
[3., 0., 0., 1.]]) [3., 0., 0., 1.]])
@ -129,17 +128,16 @@ class TestPandas(unittest.TestCase):
def test_pandas_label(self): def test_pandas_label(self):
# label must be a single column # label must be a single column
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]}) df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
pandas_handler = xgb.data.PandasHandler(np.nan, 0, False) self.assertRaises(ValueError, xgb.data._transform_pandas_df, df,
self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
None, None, 'label', 'float') None, None, 'label', 'float')
# label must be supported dtype # label must be supported dtype
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)}) df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df, self.assertRaises(ValueError, xgb.data._transform_pandas_df, df,
None, None, 'label', 'float') None, None, 'label', 'float')
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)}) df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
result, _, _ = pandas_handler._maybe_pandas_data(df, None, None, result, _, _ = xgb.data._transform_pandas_df(df, None, None,
'label', 'float') 'label', 'float')
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]], np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
dtype=float)) dtype=float))