Simplify the data backends. (#5893)

This commit is contained in:
Jiaming Yuan 2020-07-16 15:17:31 +08:00 committed by GitHub
parent 7aee0e51ed
commit 029a8b533f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 790 additions and 806 deletions

View File

@ -4,13 +4,14 @@
"""Core XGBoost Library."""
import collections
# pylint: disable=no-name-in-module,import-error
from collections.abc import Mapping # Python 3
from collections.abc import Mapping
# pylint: enable=no-name-in-module,import-error
import ctypes
import os
import re
import sys
import json
import warnings
import numpy as np
import scipy.sparse
@ -267,7 +268,6 @@ def _convert_unknown_data(data, meta=None, meta_type=None):
raise TypeError('Can not handle data from {}'.format(
type(data).__name__)) from e
else:
import warnings
warnings.warn(
'Unknown data type: ' + str(type(data)) +
', coverting it to csr_matrix')
@ -279,27 +279,6 @@ def _convert_unknown_data(data, meta=None, meta_type=None):
return data
# Either object has cuda array interface or contains columns with interfaces
def _has_cuda_array_interface(data):
return hasattr(data, '__cuda_array_interface__') or \
lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame')
def _cudf_array_interfaces(df):
'''Extract CuDF __cuda_array_interface__'''
interfaces = []
if lazy_isinstance(df, 'cudf.core.series', 'Series'):
interfaces.append(df.__cuda_array_interface__)
else:
for col in df:
interface = df[col].__cuda_array_interface__
if 'mask' in interface:
interface['mask'] = interface['mask'].__cuda_array_interface__
interfaces.append(interface)
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
return interfaces_str
class DataIter:
'''The interface for user defined data iterator. Currently is only
supported by Device DMatrix.
@ -331,7 +310,7 @@ class DataIter:
'''A wrapper for user defined `next` function.
`this` is not used in Python. ctypes can handle `self` of a Python
member function automatically when converting a it to c function
member function automatically when converting it to c function
pointer.
'''
@ -340,32 +319,30 @@ class DataIter:
def data_handle(data, label=None, weight=None, base_margin=None,
group=None,
label_lower_bound=None, label_upper_bound=None):
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
# pylint: disable=protected-access
self.proxy._set_data_from_cuda_columnar(data)
elif lazy_isinstance(data, 'cudf.core.series', 'Series'):
# pylint: disable=protected-access
self.proxy._set_data_from_cuda_columnar(data)
elif lazy_isinstance(data, 'cupy.core.core', 'ndarray'):
# pylint: disable=protected-access
self.proxy._set_data_from_cuda_interface(data)
else:
raise TypeError(
'Value type is not supported for data iterator:' +
str(type(self._handle)), type(data))
label_lower_bound=None, label_upper_bound=None,
feature_names=None, feature_types=None):
from .data import dispatch_device_quantile_dmatrix_set_data
from .data import _device_quantile_transform
data, feature_names, feature_types = _device_quantile_transform(
data, feature_names, feature_types
)
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
self.proxy.set_info(label=label, weight=weight,
base_margin=base_margin,
group=group,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound)
label_upper_bound=label_upper_bound,
feature_names=feature_names,
feature_types=feature_types)
try:
# Deffer the exception in order to return 0 and stop the iteration.
# Differ the exception in order to return 0 and stop the iteration.
# Exception inside a ctype callback function has no effect except
# for printing to stderr (doesn't stop the execution).
ret = self.next(data_handle) # pylint: disable=not-callable
except Exception as e: # pylint: disable=broad-except
tb = sys.exc_info()[2]
# On dask the worker is restarted and somehow the information is
# lost.
self.exception = e.with_traceback(tb)
return 0
return ret
@ -453,41 +430,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.handle = None
return
handler = self._get_data_handler(data)
can_handle_meta = False
if handler is None:
data = _convert_unknown_data(data, None)
handler = self._get_data_handler(data)
try:
handler.handle_meta(label, weight, base_margin)
can_handle_meta = True
except NotImplementedError:
can_handle_meta = False
from .data import dispatch_data_backend
handle, feature_names, feature_types = dispatch_data_backend(
data, missing=self.missing,
threads=self.nthread,
feature_names=feature_names,
feature_types=feature_types)
assert handle is not None
self.handle = handle
self.handle, feature_names, feature_types = handler.handle_input(
data, feature_names, feature_types)
assert self.handle, 'Failed to construct a DMatrix.'
if not can_handle_meta:
self.set_info(label, weight, base_margin)
self.set_info(label=label, weight=weight, base_margin=base_margin)
self.feature_names = feature_names
self.feature_types = feature_types
def _get_data_handler(self, data, meta=None, meta_type=None):
'''Get data handler for this DMatrix class.'''
from .data import get_dmatrix_data_handler
handler = get_dmatrix_data_handler(
data, self.missing, self.nthread, self.silent, meta, meta_type)
return handler
# pylint: disable=no-self-use
def _get_meta_handler(self, data, meta, meta_type):
from .data import get_dmatrix_meta_handler
handler = get_dmatrix_meta_handler(
data, meta, meta_type)
return handler
def __del__(self):
if hasattr(self, "handle") and self.handle:
_check_call(_LIB.XGDMatrixFree(self.handle))
@ -497,7 +453,9 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
label=None, weight=None, base_margin=None,
group=None,
label_lower_bound=None,
label_upper_bound=None):
label_upper_bound=None,
feature_names=None,
feature_types=None):
'''Set meta info for DMatrix.'''
if label is not None:
self.set_label(label)
@ -511,6 +469,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.set_float_info('label_lower_bound', label_lower_bound)
if label_upper_bound is not None:
self.set_float_info('label_upper_bound', label_upper_bound)
if feature_names is not None:
self.feature_names = feature_names
if feature_types is not None:
self.feature_types = feature_types
def get_float_info(self, field):
"""Get float property from the DMatrix.
@ -565,17 +527,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
data: numpy array
The array of data to be set
"""
if isinstance(data, np.ndarray):
self.set_float_info_npy2d(field, data)
return
handler = self._get_data_handler(data, field, np.float32)
assert handler
data, _, _ = handler.transform(data)
c_data = c_array(ctypes.c_float, data)
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
c_data,
c_bst_ulong(len(data))))
from .data import dispatch_meta_backend
dispatch_meta_backend(self, data, field, 'float')
def set_float_info_npy2d(self, field, data):
"""Set float type property into the DMatrix
@ -589,13 +542,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
data: numpy array
The array of data to be set
"""
data, _, _ = self._get_meta_handler(
data, field, np.float32).transform(data)
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
c_data,
c_bst_ulong(len(data))))
from .data import dispatch_meta_backend
dispatch_meta_backend(self, data, field, 'float')
def set_uint_info(self, field, data):
"""Set uint type property into the DMatrix.
@ -608,27 +556,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
data: numpy array
The array of data to be set
"""
data, _, _ = self._get_data_handler(
data, field, 'uint32').transform(data)
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
c_str(field),
c_array(ctypes.c_uint, data),
c_bst_ulong(len(data))))
def set_interface_info(self, field, data):
"""Set info type property into DMatrix."""
# If we are passed a dataframe, extract the series
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
if len(data.columns) != 1:
raise ValueError(
'Expecting meta-info to contain a single column')
data = data[data.columns[0]]
interface = bytes(json.dumps([data.__cuda_array_interface__],
indent=2), 'utf-8')
_check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle,
c_str(field),
interface))
from .data import dispatch_meta_backend
dispatch_meta_backend(self, data, field, 'uint32')
def save_binary(self, fname, silent=True):
"""Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
@ -653,10 +582,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
label: array like
The label information to be set into DMatrix
"""
if _has_cuda_array_interface(label):
self.set_interface_info('label', label)
else:
self.set_float_info('label', label)
from .data import dispatch_meta_backend
dispatch_meta_backend(self, label, 'label', 'float')
def set_weight(self, weight):
"""Set weight of each instance.
@ -674,10 +601,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
sense to assign weights to individual data points.
"""
if _has_cuda_array_interface(weight):
self.set_interface_info('weight', weight)
else:
self.set_float_info('weight', weight)
from .data import dispatch_meta_backend
dispatch_meta_backend(self, weight, 'weight', 'float')
def set_base_margin(self, margin):
"""Set base margin of booster to start from.
@ -693,10 +618,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
Prediction margin of each datapoint
"""
if _has_cuda_array_interface(margin):
self.set_interface_info('base_margin', margin)
else:
self.set_float_info('base_margin', margin)
from .data import dispatch_meta_backend
dispatch_meta_backend(self, margin, 'base_margin', 'float')
def set_group(self, group):
"""Set group size of DMatrix (used for ranking).
@ -706,10 +629,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
group : array like
Group size of each group
"""
if _has_cuda_array_interface(group):
self.set_interface_info('group', group)
else:
self.set_uint_info('group', group)
from .data import dispatch_meta_backend
dispatch_meta_backend(self, group, 'group', 'uint32')
def get_label(self):
"""Get the label of the DMatrix.
@ -830,7 +751,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
if len(feature_names) != len(set(feature_names)):
raise ValueError('feature_names must be unique')
if len(feature_names) != self.num_col():
if len(feature_names) != self.num_col() and self.num_col() != 0:
msg = 'feature_names must have the same length as data'
raise ValueError(msg)
# prohibit to use symbols may affect to parse. e.g. []<
@ -935,28 +856,35 @@ class DeviceQuantileDMatrix(DMatrix):
"""
def __init__(self, data, label=None, weight=None, base_margin=None,
def __init__(self, data, label=None, weight=None, # pylint: disable=W0231
base_margin=None,
missing=None,
silent=False,
feature_names=None,
feature_types=None,
nthread=None, max_bin=256):
self.max_bin = max_bin
self.missing = missing if missing is not None else np.nan
self.nthread = nthread if nthread is not None else 1
if isinstance(data, ctypes.c_void_p):
self.handle = data
return
super().__init__(data, label=label, weight=weight,
base_margin=base_margin,
missing=missing,
silent=silent,
feature_names=feature_names,
feature_types=feature_types,
nthread=nthread)
from .data import init_device_quantile_dmatrix
handle, feature_names, feature_types = init_device_quantile_dmatrix(
data, missing=self.missing, threads=self.nthread,
max_bin=self.max_bin,
label=label, weight=weight,
base_margin=base_margin,
group=None,
label_lower_bound=None,
label_upper_bound=None,
feature_names=feature_names,
feature_types=feature_types)
self.handle = handle
def _get_data_handler(self, data, meta=None, meta_type=None):
from .data import get_device_quantile_dmatrix_data_handler
return get_device_quantile_dmatrix_data_handler(
data, self.max_bin, self.missing, self.nthread, self.silent)
self.feature_names = feature_names
self.feature_types = feature_types
def _set_data_from_cuda_interface(self, data):
'''Set data from CUDA array interface.'''
@ -971,6 +899,7 @@ class DeviceQuantileDMatrix(DMatrix):
def _set_data_from_cuda_columnar(self, data):
'''Set data from CUDA columnar format.1'''
from .data import _cudf_array_interfaces
interfaces_str = _cudf_array_interfaces(data)
_check_call(
_LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
@ -1592,6 +1521,7 @@ class Booster(object):
rows = data.shape[0]
return reshape_output(mem, rows)
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
from .data import _cudf_array_interfaces
interfaces_str = _cudf_array_interfaces(data)
_check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns(
self.handle,

File diff suppressed because it is too large Load Diff

View File

@ -12,12 +12,16 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
auto const& value = adapter->Value();
this->batch_ = adapter;
device_ = adapter->DeviceIdx();
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
}
void DMatrixProxy::FromCudaArray(std::string interface_str) {
std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
this->batch_ = adapter;
device_ = adapter->DeviceIdx();
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
}
} // namespace data

View File

@ -12,11 +12,12 @@ import testing as tm
class TestDeviceQuantileDMatrix(unittest.TestCase):
def test_dmatrix_numpy_init(self):
data = np.random.randn(5, 5)
with pytest.raises(AssertionError, match='is not supported for DeviceQuantileDMatrix'):
dm = xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
with pytest.raises(TypeError,
match='is not supported for DeviceQuantileDMatrix'):
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
@pytest.mark.skipif(**tm.no_cupy())
def test_dmatrix_cupy_init(self):
import cupy as cp
data = cp.random.randn(5, 5)
dm = xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))

View File

@ -119,10 +119,10 @@ def _test_cudf_metainfo(DMatrixT):
dmat.set_float_info('label', floats)
dmat.set_float_info('base_margin', floats)
dmat.set_uint_info('group', uints)
dmat_cudf.set_interface_info('weight', cudf_floats)
dmat_cudf.set_interface_info('label', cudf_floats)
dmat_cudf.set_interface_info('base_margin', cudf_floats)
dmat_cudf.set_interface_info('group', cudf_uints)
dmat_cudf.set_info(weight=cudf_floats)
dmat_cudf.set_info(label=cudf_floats)
dmat_cudf.set_info(base_margin=cudf_floats)
dmat_cudf.set_info(group=cudf_uints)
# Test setting info with cudf DataFrame
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
@ -132,10 +132,10 @@ def _test_cudf_metainfo(DMatrixT):
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
# Test setting info with cudf Series
dmat_cudf.set_interface_info('weight', cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_interface_info('label', cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_interface_info('base_margin', cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_interface_info('group', cudf_uints[cudf_uints.columns[0]])
dmat_cudf.set_info(weight=cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_info(label=cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_info(base_margin=cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_info(group=cudf_uints[cudf_uints.columns[0]])
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
assert np.array_equal(dmat.get_float_info('base_margin'),

View File

@ -92,10 +92,10 @@ def _test_cupy_metainfo(DMatrixT):
dmat.set_float_info('label', floats)
dmat.set_float_info('base_margin', floats)
dmat.set_uint_info('group', uints)
dmat_cupy.set_interface_info('weight', cupy_floats)
dmat_cupy.set_interface_info('label', cupy_floats)
dmat_cupy.set_interface_info('base_margin', cupy_floats)
dmat_cupy.set_interface_info('group', cupy_uints)
dmat_cupy.set_info(weight=cupy_floats)
dmat_cupy.set_info(label=cupy_floats)
dmat_cupy.set_info(base_margin=cupy_floats)
dmat_cupy.set_info(group=cupy_uints)
# Test setting info with cupy
assert np.array_equal(dmat.get_float_info('weight'),

View File

@ -1,17 +1,14 @@
# -*- coding: utf-8 -*-
import sys
from contextlib import contextmanager
try:
# python 2
from StringIO import StringIO
except ImportError:
# python 3
from io import StringIO
from io import StringIO
import numpy as np
import os
import xgboost as xgb
import unittest
import json
from pathlib import Path
import tempfile
dpath = 'demo/data/'
rng = np.random.RandomState(1994)
@ -66,16 +63,19 @@ class TestBasic(unittest.TestCase):
# error must be smaller than 10%
assert err < 0.1
# save dmatrix into binary buffer
dtest.save_binary('dtest.buffer')
# save model
bst.save_model('xgb.model')
# load model and data in
bst2 = xgb.Booster(model_file='xgb.model')
dtest2 = xgb.DMatrix('dtest.buffer')
preds2 = bst2.predict(dtest2)
# assert they are the same
assert np.sum(np.abs(preds2 - preds)) == 0
with tempfile.TemporaryDirectory() as tmpdir:
dtest_path = os.path.join(tmpdir, 'dtest.dmatrix')
# save dmatrix into binary buffer
dtest.save_binary(dtest_path)
# save model
model_path = os.path.join(tmpdir, 'model.booster')
bst.save_model(model_path)
# load model and data in
bst2 = xgb.Booster(model_file=model_path)
dtest2 = xgb.DMatrix(dtest_path)
preds2 = bst2.predict(dtest2)
# assert they are the same
assert np.sum(np.abs(preds2 - preds)) == 0
def test_record_results(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')

View File

@ -67,8 +67,7 @@ class TestPandas(unittest.TestCase):
# 0 1 1 0 0
# 1 2 0 1 0
# 2 3 0 0 1
pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
result, _, _ = pandas_handler._maybe_pandas_data(dummies, None, None)
result, _, _ = xgb.data._transform_pandas_df(dummies)
exp = np.array([[1., 1., 0., 0.],
[2., 0., 1., 0.],
[3., 0., 0., 1.]])
@ -129,18 +128,17 @@ class TestPandas(unittest.TestCase):
def test_pandas_label(self):
# label must be a single column
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
self.assertRaises(ValueError, xgb.data._transform_pandas_df, df,
None, None, 'label', 'float')
# label must be supported dtype
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
self.assertRaises(ValueError, xgb.data._transform_pandas_df, df,
None, None, 'label', 'float')
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
result, _, _ = pandas_handler._maybe_pandas_data(df, None, None,
'label', 'float')
result, _, _ = xgb.data._transform_pandas_df(df, None, None,
'label', 'float')
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
dtype=float))
dm = xgb.DMatrix(np.random.randn(3, 2), label=df)