Simplify the data backends. (#5893)
This commit is contained in:
parent
7aee0e51ed
commit
029a8b533f
@ -4,13 +4,14 @@
|
|||||||
"""Core XGBoost Library."""
|
"""Core XGBoost Library."""
|
||||||
import collections
|
import collections
|
||||||
# pylint: disable=no-name-in-module,import-error
|
# pylint: disable=no-name-in-module,import-error
|
||||||
from collections.abc import Mapping # Python 3
|
from collections.abc import Mapping
|
||||||
# pylint: enable=no-name-in-module,import-error
|
# pylint: enable=no-name-in-module,import-error
|
||||||
import ctypes
|
import ctypes
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
@ -267,7 +268,6 @@ def _convert_unknown_data(data, meta=None, meta_type=None):
|
|||||||
raise TypeError('Can not handle data from {}'.format(
|
raise TypeError('Can not handle data from {}'.format(
|
||||||
type(data).__name__)) from e
|
type(data).__name__)) from e
|
||||||
else:
|
else:
|
||||||
import warnings
|
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
'Unknown data type: ' + str(type(data)) +
|
'Unknown data type: ' + str(type(data)) +
|
||||||
', coverting it to csr_matrix')
|
', coverting it to csr_matrix')
|
||||||
@ -279,27 +279,6 @@ def _convert_unknown_data(data, meta=None, meta_type=None):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
# Either object has cuda array interface or contains columns with interfaces
|
|
||||||
def _has_cuda_array_interface(data):
|
|
||||||
return hasattr(data, '__cuda_array_interface__') or \
|
|
||||||
lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame')
|
|
||||||
|
|
||||||
|
|
||||||
def _cudf_array_interfaces(df):
|
|
||||||
'''Extract CuDF __cuda_array_interface__'''
|
|
||||||
interfaces = []
|
|
||||||
if lazy_isinstance(df, 'cudf.core.series', 'Series'):
|
|
||||||
interfaces.append(df.__cuda_array_interface__)
|
|
||||||
else:
|
|
||||||
for col in df:
|
|
||||||
interface = df[col].__cuda_array_interface__
|
|
||||||
if 'mask' in interface:
|
|
||||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
|
||||||
interfaces.append(interface)
|
|
||||||
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
|
|
||||||
return interfaces_str
|
|
||||||
|
|
||||||
|
|
||||||
class DataIter:
|
class DataIter:
|
||||||
'''The interface for user defined data iterator. Currently is only
|
'''The interface for user defined data iterator. Currently is only
|
||||||
supported by Device DMatrix.
|
supported by Device DMatrix.
|
||||||
@ -331,7 +310,7 @@ class DataIter:
|
|||||||
'''A wrapper for user defined `next` function.
|
'''A wrapper for user defined `next` function.
|
||||||
|
|
||||||
`this` is not used in Python. ctypes can handle `self` of a Python
|
`this` is not used in Python. ctypes can handle `self` of a Python
|
||||||
member function automatically when converting a it to c function
|
member function automatically when converting it to c function
|
||||||
pointer.
|
pointer.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
@ -340,32 +319,30 @@ class DataIter:
|
|||||||
|
|
||||||
def data_handle(data, label=None, weight=None, base_margin=None,
|
def data_handle(data, label=None, weight=None, base_margin=None,
|
||||||
group=None,
|
group=None,
|
||||||
label_lower_bound=None, label_upper_bound=None):
|
label_lower_bound=None, label_upper_bound=None,
|
||||||
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
|
feature_names=None, feature_types=None):
|
||||||
# pylint: disable=protected-access
|
from .data import dispatch_device_quantile_dmatrix_set_data
|
||||||
self.proxy._set_data_from_cuda_columnar(data)
|
from .data import _device_quantile_transform
|
||||||
elif lazy_isinstance(data, 'cudf.core.series', 'Series'):
|
data, feature_names, feature_types = _device_quantile_transform(
|
||||||
# pylint: disable=protected-access
|
data, feature_names, feature_types
|
||||||
self.proxy._set_data_from_cuda_columnar(data)
|
)
|
||||||
elif lazy_isinstance(data, 'cupy.core.core', 'ndarray'):
|
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
|
||||||
# pylint: disable=protected-access
|
|
||||||
self.proxy._set_data_from_cuda_interface(data)
|
|
||||||
else:
|
|
||||||
raise TypeError(
|
|
||||||
'Value type is not supported for data iterator:' +
|
|
||||||
str(type(self._handle)), type(data))
|
|
||||||
self.proxy.set_info(label=label, weight=weight,
|
self.proxy.set_info(label=label, weight=weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
group=group,
|
group=group,
|
||||||
label_lower_bound=label_lower_bound,
|
label_lower_bound=label_lower_bound,
|
||||||
label_upper_bound=label_upper_bound)
|
label_upper_bound=label_upper_bound,
|
||||||
|
feature_names=feature_names,
|
||||||
|
feature_types=feature_types)
|
||||||
try:
|
try:
|
||||||
# Deffer the exception in order to return 0 and stop the iteration.
|
# Differ the exception in order to return 0 and stop the iteration.
|
||||||
# Exception inside a ctype callback function has no effect except
|
# Exception inside a ctype callback function has no effect except
|
||||||
# for printing to stderr (doesn't stop the execution).
|
# for printing to stderr (doesn't stop the execution).
|
||||||
ret = self.next(data_handle) # pylint: disable=not-callable
|
ret = self.next(data_handle) # pylint: disable=not-callable
|
||||||
except Exception as e: # pylint: disable=broad-except
|
except Exception as e: # pylint: disable=broad-except
|
||||||
tb = sys.exc_info()[2]
|
tb = sys.exc_info()[2]
|
||||||
|
# On dask the worker is restarted and somehow the information is
|
||||||
|
# lost.
|
||||||
self.exception = e.with_traceback(tb)
|
self.exception = e.with_traceback(tb)
|
||||||
return 0
|
return 0
|
||||||
return ret
|
return ret
|
||||||
@ -453,41 +430,20 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
self.handle = None
|
self.handle = None
|
||||||
return
|
return
|
||||||
|
|
||||||
handler = self._get_data_handler(data)
|
from .data import dispatch_data_backend
|
||||||
can_handle_meta = False
|
handle, feature_names, feature_types = dispatch_data_backend(
|
||||||
if handler is None:
|
data, missing=self.missing,
|
||||||
data = _convert_unknown_data(data, None)
|
threads=self.nthread,
|
||||||
handler = self._get_data_handler(data)
|
feature_names=feature_names,
|
||||||
try:
|
feature_types=feature_types)
|
||||||
handler.handle_meta(label, weight, base_margin)
|
assert handle is not None
|
||||||
can_handle_meta = True
|
self.handle = handle
|
||||||
except NotImplementedError:
|
|
||||||
can_handle_meta = False
|
|
||||||
|
|
||||||
self.handle, feature_names, feature_types = handler.handle_input(
|
self.set_info(label=label, weight=weight, base_margin=base_margin)
|
||||||
data, feature_names, feature_types)
|
|
||||||
assert self.handle, 'Failed to construct a DMatrix.'
|
|
||||||
|
|
||||||
if not can_handle_meta:
|
|
||||||
self.set_info(label, weight, base_margin)
|
|
||||||
|
|
||||||
self.feature_names = feature_names
|
self.feature_names = feature_names
|
||||||
self.feature_types = feature_types
|
self.feature_types = feature_types
|
||||||
|
|
||||||
def _get_data_handler(self, data, meta=None, meta_type=None):
|
|
||||||
'''Get data handler for this DMatrix class.'''
|
|
||||||
from .data import get_dmatrix_data_handler
|
|
||||||
handler = get_dmatrix_data_handler(
|
|
||||||
data, self.missing, self.nthread, self.silent, meta, meta_type)
|
|
||||||
return handler
|
|
||||||
|
|
||||||
# pylint: disable=no-self-use
|
|
||||||
def _get_meta_handler(self, data, meta, meta_type):
|
|
||||||
from .data import get_dmatrix_meta_handler
|
|
||||||
handler = get_dmatrix_meta_handler(
|
|
||||||
data, meta, meta_type)
|
|
||||||
return handler
|
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if hasattr(self, "handle") and self.handle:
|
if hasattr(self, "handle") and self.handle:
|
||||||
_check_call(_LIB.XGDMatrixFree(self.handle))
|
_check_call(_LIB.XGDMatrixFree(self.handle))
|
||||||
@ -497,7 +453,9 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
label=None, weight=None, base_margin=None,
|
label=None, weight=None, base_margin=None,
|
||||||
group=None,
|
group=None,
|
||||||
label_lower_bound=None,
|
label_lower_bound=None,
|
||||||
label_upper_bound=None):
|
label_upper_bound=None,
|
||||||
|
feature_names=None,
|
||||||
|
feature_types=None):
|
||||||
'''Set meta info for DMatrix.'''
|
'''Set meta info for DMatrix.'''
|
||||||
if label is not None:
|
if label is not None:
|
||||||
self.set_label(label)
|
self.set_label(label)
|
||||||
@ -511,6 +469,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
self.set_float_info('label_lower_bound', label_lower_bound)
|
self.set_float_info('label_lower_bound', label_lower_bound)
|
||||||
if label_upper_bound is not None:
|
if label_upper_bound is not None:
|
||||||
self.set_float_info('label_upper_bound', label_upper_bound)
|
self.set_float_info('label_upper_bound', label_upper_bound)
|
||||||
|
if feature_names is not None:
|
||||||
|
self.feature_names = feature_names
|
||||||
|
if feature_types is not None:
|
||||||
|
self.feature_types = feature_types
|
||||||
|
|
||||||
def get_float_info(self, field):
|
def get_float_info(self, field):
|
||||||
"""Get float property from the DMatrix.
|
"""Get float property from the DMatrix.
|
||||||
@ -565,17 +527,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
data: numpy array
|
data: numpy array
|
||||||
The array of data to be set
|
The array of data to be set
|
||||||
"""
|
"""
|
||||||
if isinstance(data, np.ndarray):
|
from .data import dispatch_meta_backend
|
||||||
self.set_float_info_npy2d(field, data)
|
dispatch_meta_backend(self, data, field, 'float')
|
||||||
return
|
|
||||||
handler = self._get_data_handler(data, field, np.float32)
|
|
||||||
assert handler
|
|
||||||
data, _, _ = handler.transform(data)
|
|
||||||
c_data = c_array(ctypes.c_float, data)
|
|
||||||
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
|
||||||
c_str(field),
|
|
||||||
c_data,
|
|
||||||
c_bst_ulong(len(data))))
|
|
||||||
|
|
||||||
def set_float_info_npy2d(self, field, data):
|
def set_float_info_npy2d(self, field, data):
|
||||||
"""Set float type property into the DMatrix
|
"""Set float type property into the DMatrix
|
||||||
@ -589,13 +542,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
data: numpy array
|
data: numpy array
|
||||||
The array of data to be set
|
The array of data to be set
|
||||||
"""
|
"""
|
||||||
data, _, _ = self._get_meta_handler(
|
from .data import dispatch_meta_backend
|
||||||
data, field, np.float32).transform(data)
|
dispatch_meta_backend(self, data, field, 'float')
|
||||||
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
|
||||||
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
|
|
||||||
c_str(field),
|
|
||||||
c_data,
|
|
||||||
c_bst_ulong(len(data))))
|
|
||||||
|
|
||||||
def set_uint_info(self, field, data):
|
def set_uint_info(self, field, data):
|
||||||
"""Set uint type property into the DMatrix.
|
"""Set uint type property into the DMatrix.
|
||||||
@ -608,27 +556,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
data: numpy array
|
data: numpy array
|
||||||
The array of data to be set
|
The array of data to be set
|
||||||
"""
|
"""
|
||||||
data, _, _ = self._get_data_handler(
|
from .data import dispatch_meta_backend
|
||||||
data, field, 'uint32').transform(data)
|
dispatch_meta_backend(self, data, field, 'uint32')
|
||||||
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
|
|
||||||
c_str(field),
|
|
||||||
c_array(ctypes.c_uint, data),
|
|
||||||
c_bst_ulong(len(data))))
|
|
||||||
|
|
||||||
def set_interface_info(self, field, data):
|
|
||||||
"""Set info type property into DMatrix."""
|
|
||||||
# If we are passed a dataframe, extract the series
|
|
||||||
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
|
|
||||||
if len(data.columns) != 1:
|
|
||||||
raise ValueError(
|
|
||||||
'Expecting meta-info to contain a single column')
|
|
||||||
data = data[data.columns[0]]
|
|
||||||
|
|
||||||
interface = bytes(json.dumps([data.__cuda_array_interface__],
|
|
||||||
indent=2), 'utf-8')
|
|
||||||
_check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle,
|
|
||||||
c_str(field),
|
|
||||||
interface))
|
|
||||||
|
|
||||||
def save_binary(self, fname, silent=True):
|
def save_binary(self, fname, silent=True):
|
||||||
"""Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
|
"""Save DMatrix to an XGBoost buffer. Saved binary can be later loaded
|
||||||
@ -653,10 +582,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
label: array like
|
label: array like
|
||||||
The label information to be set into DMatrix
|
The label information to be set into DMatrix
|
||||||
"""
|
"""
|
||||||
if _has_cuda_array_interface(label):
|
from .data import dispatch_meta_backend
|
||||||
self.set_interface_info('label', label)
|
dispatch_meta_backend(self, label, 'label', 'float')
|
||||||
else:
|
|
||||||
self.set_float_info('label', label)
|
|
||||||
|
|
||||||
def set_weight(self, weight):
|
def set_weight(self, weight):
|
||||||
"""Set weight of each instance.
|
"""Set weight of each instance.
|
||||||
@ -674,10 +601,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
sense to assign weights to individual data points.
|
sense to assign weights to individual data points.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if _has_cuda_array_interface(weight):
|
from .data import dispatch_meta_backend
|
||||||
self.set_interface_info('weight', weight)
|
dispatch_meta_backend(self, weight, 'weight', 'float')
|
||||||
else:
|
|
||||||
self.set_float_info('weight', weight)
|
|
||||||
|
|
||||||
def set_base_margin(self, margin):
|
def set_base_margin(self, margin):
|
||||||
"""Set base margin of booster to start from.
|
"""Set base margin of booster to start from.
|
||||||
@ -693,10 +618,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
Prediction margin of each datapoint
|
Prediction margin of each datapoint
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if _has_cuda_array_interface(margin):
|
from .data import dispatch_meta_backend
|
||||||
self.set_interface_info('base_margin', margin)
|
dispatch_meta_backend(self, margin, 'base_margin', 'float')
|
||||||
else:
|
|
||||||
self.set_float_info('base_margin', margin)
|
|
||||||
|
|
||||||
def set_group(self, group):
|
def set_group(self, group):
|
||||||
"""Set group size of DMatrix (used for ranking).
|
"""Set group size of DMatrix (used for ranking).
|
||||||
@ -706,10 +629,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
group : array like
|
group : array like
|
||||||
Group size of each group
|
Group size of each group
|
||||||
"""
|
"""
|
||||||
if _has_cuda_array_interface(group):
|
from .data import dispatch_meta_backend
|
||||||
self.set_interface_info('group', group)
|
dispatch_meta_backend(self, group, 'group', 'uint32')
|
||||||
else:
|
|
||||||
self.set_uint_info('group', group)
|
|
||||||
|
|
||||||
def get_label(self):
|
def get_label(self):
|
||||||
"""Get the label of the DMatrix.
|
"""Get the label of the DMatrix.
|
||||||
@ -830,7 +751,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
|
|
||||||
if len(feature_names) != len(set(feature_names)):
|
if len(feature_names) != len(set(feature_names)):
|
||||||
raise ValueError('feature_names must be unique')
|
raise ValueError('feature_names must be unique')
|
||||||
if len(feature_names) != self.num_col():
|
if len(feature_names) != self.num_col() and self.num_col() != 0:
|
||||||
msg = 'feature_names must have the same length as data'
|
msg = 'feature_names must have the same length as data'
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
# prohibit to use symbols may affect to parse. e.g. []<
|
# prohibit to use symbols may affect to parse. e.g. []<
|
||||||
@ -935,28 +856,35 @@ class DeviceQuantileDMatrix(DMatrix):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data, label=None, weight=None, base_margin=None,
|
def __init__(self, data, label=None, weight=None, # pylint: disable=W0231
|
||||||
|
base_margin=None,
|
||||||
missing=None,
|
missing=None,
|
||||||
silent=False,
|
silent=False,
|
||||||
feature_names=None,
|
feature_names=None,
|
||||||
feature_types=None,
|
feature_types=None,
|
||||||
nthread=None, max_bin=256):
|
nthread=None, max_bin=256):
|
||||||
self.max_bin = max_bin
|
self.max_bin = max_bin
|
||||||
|
self.missing = missing if missing is not None else np.nan
|
||||||
|
self.nthread = nthread if nthread is not None else 1
|
||||||
|
|
||||||
if isinstance(data, ctypes.c_void_p):
|
if isinstance(data, ctypes.c_void_p):
|
||||||
self.handle = data
|
self.handle = data
|
||||||
return
|
return
|
||||||
super().__init__(data, label=label, weight=weight,
|
from .data import init_device_quantile_dmatrix
|
||||||
|
handle, feature_names, feature_types = init_device_quantile_dmatrix(
|
||||||
|
data, missing=self.missing, threads=self.nthread,
|
||||||
|
max_bin=self.max_bin,
|
||||||
|
label=label, weight=weight,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
missing=missing,
|
group=None,
|
||||||
silent=silent,
|
label_lower_bound=None,
|
||||||
|
label_upper_bound=None,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types)
|
||||||
nthread=nthread)
|
self.handle = handle
|
||||||
|
|
||||||
def _get_data_handler(self, data, meta=None, meta_type=None):
|
self.feature_names = feature_names
|
||||||
from .data import get_device_quantile_dmatrix_data_handler
|
self.feature_types = feature_types
|
||||||
return get_device_quantile_dmatrix_data_handler(
|
|
||||||
data, self.max_bin, self.missing, self.nthread, self.silent)
|
|
||||||
|
|
||||||
def _set_data_from_cuda_interface(self, data):
|
def _set_data_from_cuda_interface(self, data):
|
||||||
'''Set data from CUDA array interface.'''
|
'''Set data from CUDA array interface.'''
|
||||||
@ -971,6 +899,7 @@ class DeviceQuantileDMatrix(DMatrix):
|
|||||||
|
|
||||||
def _set_data_from_cuda_columnar(self, data):
|
def _set_data_from_cuda_columnar(self, data):
|
||||||
'''Set data from CUDA columnar format.1'''
|
'''Set data from CUDA columnar format.1'''
|
||||||
|
from .data import _cudf_array_interfaces
|
||||||
interfaces_str = _cudf_array_interfaces(data)
|
interfaces_str = _cudf_array_interfaces(data)
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
|
_LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
|
||||||
@ -1592,6 +1521,7 @@ class Booster(object):
|
|||||||
rows = data.shape[0]
|
rows = data.shape[0]
|
||||||
return reshape_output(mem, rows)
|
return reshape_output(mem, rows)
|
||||||
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
|
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
|
||||||
|
from .data import _cudf_array_interfaces
|
||||||
interfaces_str = _cudf_array_interfaces(data)
|
interfaces_str = _cudf_array_interfaces(data)
|
||||||
_check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns(
|
_check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns(
|
||||||
self.handle,
|
self.handle,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -12,12 +12,16 @@ void DMatrixProxy::FromCudaColumnar(std::string interface_str) {
|
|||||||
auto const& value = adapter->Value();
|
auto const& value = adapter->Value();
|
||||||
this->batch_ = adapter;
|
this->batch_ = adapter;
|
||||||
device_ = adapter->DeviceIdx();
|
device_ = adapter->DeviceIdx();
|
||||||
|
this->Info().num_col_ = adapter->NumColumns();
|
||||||
|
this->Info().num_row_ = adapter->NumRows();
|
||||||
}
|
}
|
||||||
|
|
||||||
void DMatrixProxy::FromCudaArray(std::string interface_str) {
|
void DMatrixProxy::FromCudaArray(std::string interface_str) {
|
||||||
std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
|
std::shared_ptr<CupyAdapter> adapter(new CupyAdapter(interface_str));
|
||||||
this->batch_ = adapter;
|
this->batch_ = adapter;
|
||||||
device_ = adapter->DeviceIdx();
|
device_ = adapter->DeviceIdx();
|
||||||
|
this->Info().num_col_ = adapter->NumColumns();
|
||||||
|
this->Info().num_row_ = adapter->NumRows();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
|
|||||||
@ -12,11 +12,12 @@ import testing as tm
|
|||||||
class TestDeviceQuantileDMatrix(unittest.TestCase):
|
class TestDeviceQuantileDMatrix(unittest.TestCase):
|
||||||
def test_dmatrix_numpy_init(self):
|
def test_dmatrix_numpy_init(self):
|
||||||
data = np.random.randn(5, 5)
|
data = np.random.randn(5, 5)
|
||||||
with pytest.raises(AssertionError, match='is not supported for DeviceQuantileDMatrix'):
|
with pytest.raises(TypeError,
|
||||||
dm = xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
|
match='is not supported for DeviceQuantileDMatrix'):
|
||||||
|
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
def test_dmatrix_cupy_init(self):
|
def test_dmatrix_cupy_init(self):
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
data = cp.random.randn(5, 5)
|
data = cp.random.randn(5, 5)
|
||||||
dm = xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
xgb.DeviceQuantileDMatrix(data, cp.ones(5, dtype=np.float64))
|
||||||
|
|||||||
@ -119,10 +119,10 @@ def _test_cudf_metainfo(DMatrixT):
|
|||||||
dmat.set_float_info('label', floats)
|
dmat.set_float_info('label', floats)
|
||||||
dmat.set_float_info('base_margin', floats)
|
dmat.set_float_info('base_margin', floats)
|
||||||
dmat.set_uint_info('group', uints)
|
dmat.set_uint_info('group', uints)
|
||||||
dmat_cudf.set_interface_info('weight', cudf_floats)
|
dmat_cudf.set_info(weight=cudf_floats)
|
||||||
dmat_cudf.set_interface_info('label', cudf_floats)
|
dmat_cudf.set_info(label=cudf_floats)
|
||||||
dmat_cudf.set_interface_info('base_margin', cudf_floats)
|
dmat_cudf.set_info(base_margin=cudf_floats)
|
||||||
dmat_cudf.set_interface_info('group', cudf_uints)
|
dmat_cudf.set_info(group=cudf_uints)
|
||||||
|
|
||||||
# Test setting info with cudf DataFrame
|
# Test setting info with cudf DataFrame
|
||||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
||||||
@ -132,10 +132,10 @@ def _test_cudf_metainfo(DMatrixT):
|
|||||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
||||||
|
|
||||||
# Test setting info with cudf Series
|
# Test setting info with cudf Series
|
||||||
dmat_cudf.set_interface_info('weight', cudf_floats[cudf_floats.columns[0]])
|
dmat_cudf.set_info(weight=cudf_floats[cudf_floats.columns[0]])
|
||||||
dmat_cudf.set_interface_info('label', cudf_floats[cudf_floats.columns[0]])
|
dmat_cudf.set_info(label=cudf_floats[cudf_floats.columns[0]])
|
||||||
dmat_cudf.set_interface_info('base_margin', cudf_floats[cudf_floats.columns[0]])
|
dmat_cudf.set_info(base_margin=cudf_floats[cudf_floats.columns[0]])
|
||||||
dmat_cudf.set_interface_info('group', cudf_uints[cudf_uints.columns[0]])
|
dmat_cudf.set_info(group=cudf_uints[cudf_uints.columns[0]])
|
||||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
||||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
|
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
|
||||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||||
|
|||||||
@ -92,10 +92,10 @@ def _test_cupy_metainfo(DMatrixT):
|
|||||||
dmat.set_float_info('label', floats)
|
dmat.set_float_info('label', floats)
|
||||||
dmat.set_float_info('base_margin', floats)
|
dmat.set_float_info('base_margin', floats)
|
||||||
dmat.set_uint_info('group', uints)
|
dmat.set_uint_info('group', uints)
|
||||||
dmat_cupy.set_interface_info('weight', cupy_floats)
|
dmat_cupy.set_info(weight=cupy_floats)
|
||||||
dmat_cupy.set_interface_info('label', cupy_floats)
|
dmat_cupy.set_info(label=cupy_floats)
|
||||||
dmat_cupy.set_interface_info('base_margin', cupy_floats)
|
dmat_cupy.set_info(base_margin=cupy_floats)
|
||||||
dmat_cupy.set_interface_info('group', cupy_uints)
|
dmat_cupy.set_info(group=cupy_uints)
|
||||||
|
|
||||||
# Test setting info with cupy
|
# Test setting info with cupy
|
||||||
assert np.array_equal(dmat.get_float_info('weight'),
|
assert np.array_equal(dmat.get_float_info('weight'),
|
||||||
|
|||||||
@ -1,17 +1,14 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import sys
|
import sys
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
try:
|
|
||||||
# python 2
|
|
||||||
from StringIO import StringIO
|
|
||||||
except ImportError:
|
|
||||||
# python 3
|
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import unittest
|
import unittest
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import tempfile
|
||||||
|
|
||||||
dpath = 'demo/data/'
|
dpath = 'demo/data/'
|
||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
@ -66,13 +63,16 @@ class TestBasic(unittest.TestCase):
|
|||||||
# error must be smaller than 10%
|
# error must be smaller than 10%
|
||||||
assert err < 0.1
|
assert err < 0.1
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
dtest_path = os.path.join(tmpdir, 'dtest.dmatrix')
|
||||||
# save dmatrix into binary buffer
|
# save dmatrix into binary buffer
|
||||||
dtest.save_binary('dtest.buffer')
|
dtest.save_binary(dtest_path)
|
||||||
# save model
|
# save model
|
||||||
bst.save_model('xgb.model')
|
model_path = os.path.join(tmpdir, 'model.booster')
|
||||||
|
bst.save_model(model_path)
|
||||||
# load model and data in
|
# load model and data in
|
||||||
bst2 = xgb.Booster(model_file='xgb.model')
|
bst2 = xgb.Booster(model_file=model_path)
|
||||||
dtest2 = xgb.DMatrix('dtest.buffer')
|
dtest2 = xgb.DMatrix(dtest_path)
|
||||||
preds2 = bst2.predict(dtest2)
|
preds2 = bst2.predict(dtest2)
|
||||||
# assert they are the same
|
# assert they are the same
|
||||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||||
|
|||||||
@ -67,8 +67,7 @@ class TestPandas(unittest.TestCase):
|
|||||||
# 0 1 1 0 0
|
# 0 1 1 0 0
|
||||||
# 1 2 0 1 0
|
# 1 2 0 1 0
|
||||||
# 2 3 0 0 1
|
# 2 3 0 0 1
|
||||||
pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
|
result, _, _ = xgb.data._transform_pandas_df(dummies)
|
||||||
result, _, _ = pandas_handler._maybe_pandas_data(dummies, None, None)
|
|
||||||
exp = np.array([[1., 1., 0., 0.],
|
exp = np.array([[1., 1., 0., 0.],
|
||||||
[2., 0., 1., 0.],
|
[2., 0., 1., 0.],
|
||||||
[3., 0., 0., 1.]])
|
[3., 0., 0., 1.]])
|
||||||
@ -129,17 +128,16 @@ class TestPandas(unittest.TestCase):
|
|||||||
def test_pandas_label(self):
|
def test_pandas_label(self):
|
||||||
# label must be a single column
|
# label must be a single column
|
||||||
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
df = pd.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||||
pandas_handler = xgb.data.PandasHandler(np.nan, 0, False)
|
self.assertRaises(ValueError, xgb.data._transform_pandas_df, df,
|
||||||
self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
|
|
||||||
None, None, 'label', 'float')
|
None, None, 'label', 'float')
|
||||||
|
|
||||||
# label must be supported dtype
|
# label must be supported dtype
|
||||||
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
|
df = pd.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
|
||||||
self.assertRaises(ValueError, pandas_handler._maybe_pandas_data, df,
|
self.assertRaises(ValueError, xgb.data._transform_pandas_df, df,
|
||||||
None, None, 'label', 'float')
|
None, None, 'label', 'float')
|
||||||
|
|
||||||
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
|
df = pd.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
|
||||||
result, _, _ = pandas_handler._maybe_pandas_data(df, None, None,
|
result, _, _ = xgb.data._transform_pandas_df(df, None, None,
|
||||||
'label', 'float')
|
'label', 'float')
|
||||||
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
|
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
|
||||||
dtype=float))
|
dtype=float))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user