parent 048d969be4
commit a3ec964346

demo/guide-python/data_iterator.py (new file, 109 lines)
@@ -0,0 +1,109 @@
'''A demo for defining a data iterator.

The demo defines a customized iterator for passing batches of data into
`xgboost.DeviceQuantileDMatrix` and uses this `DeviceQuantileDMatrix` for
training.  The feature is primarily designed to reduce the required GPU
memory when training in a distributed environment.

After going through the demo, one might ask why we don't use a more native
Python iterator.  That's because XGBoost requires a `reset` function, while
using `itertools.tee` might incur significant memory usage according to:

  https://docs.python.org/3/library/itertools.html#itertools.tee.

'''

import xgboost
import cupy
import numpy

COLS = 64
ROWS_PER_BATCH = 1000           # data is split by rows
BATCHES = 32


class IterForDMatrixDemo(xgboost.core.DataIter):
    '''A data iterator for XGBoost DMatrix.

    `reset` and `next` are required for any data iterator; the other
    functions here are utilities for demonstration purposes.

    '''
    def __init__(self):
        '''Generate some random data for demonstration.

        Actual data can be anything that is currently supported by XGBoost.
        '''
        self.rows = ROWS_PER_BATCH
        self.cols = COLS
        rng = cupy.random.RandomState(1994)
        self._data = [rng.randn(self.rows, self.cols)] * BATCHES
        self._labels = [rng.randn(self.rows)] * BATCHES
        self._weights = [rng.randn(self.rows)] * BATCHES

        self.it = 0             # set iterator to 0
        super().__init__()

    def as_array(self):
        return cupy.concatenate(self._data)

    def as_array_labels(self):
        return cupy.concatenate(self._labels)

    def as_array_weights(self):
        return cupy.concatenate(self._weights)

    def data(self):
        '''Utility function for obtaining current batch of data.'''
        return self._data[self.it]

    def labels(self):
        '''Utility function for obtaining current batch of label.'''
        return self._labels[self.it]

    def weights(self):
        '''Utility function for obtaining current batch of weight.'''
        return self._weights[self.it]

    def reset(self):
        '''Reset the iterator.'''
        self.it = 0

    def next(self, input_data):
        '''Yield the next batch of data.'''
        if self.it == len(self._data):
            # Return 0 when there's no more batch.
            return 0
        input_data(data=self.data(), label=self.labels(),
                   weight=self.weights())
        self.it += 1
        return 1


def main():
    rounds = 100
    it = IterForDMatrixDemo()

    # Use the iterator; the input must go through `DeviceQuantileDMatrix`.
    m_with_it = xgboost.DeviceQuantileDMatrix(it)

    # Use a regular DMatrix.
    m = xgboost.DMatrix(it.as_array(), it.as_array_labels(),
                        weight=it.as_array_weights())

    assert m_with_it.num_col() == m.num_col()
    assert m_with_it.num_row() == m.num_row()

    reg_with_it = xgboost.train({'tree_method': 'gpu_hist'}, m_with_it,
                                num_boost_round=rounds)
    predict_with_it = reg_with_it.predict(m_with_it)

    reg = xgboost.train({'tree_method': 'gpu_hist'}, m,
                        num_boost_round=rounds)
    predict = reg.predict(m)

    numpy.testing.assert_allclose(predict_with_it, predict,
                                  rtol=1e-6)


if __name__ == '__main__':
    main()
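The demo above keeps every batch in GPU memory so that it can also build a regular DMatrix for comparison. The same `reset`/`next` contract works just as well when batches are loaded lazily, one shard at a time. A minimal sketch under assumed inputs follows; the shard paths and the `load_batch` helper are hypothetical and not part of XGBoost.

import cupy
import xgboost


def load_batch(path):
    # Hypothetical loader: read one shard from disk into a GPU array.
    arr = cupy.load(path)
    return arr[:, :-1], arr[:, -1]      # features, label


class IterFromFiles(xgboost.core.DataIter):
    '''Feed shards stored on disk one batch at a time.'''
    def __init__(self, paths):
        self._paths = paths
        self._it = 0
        super().__init__()

    def reset(self):
        self._it = 0

    def next(self, input_data):
        if self._it == len(self._paths):
            return 0                    # no more batches
        X, y = load_batch(self._paths[self._it])
        input_data(data=X, label=y)
        self._it += 1
        return 1


# it = IterFromFiles(['shard-0.npy', 'shard-1.npy'])
# dtrain = xgboost.DeviceQuantileDMatrix(it)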
@@ -300,6 +300,99 @@ def _cudf_array_interfaces(df):
    return interfaces_str


class DataIter:
    '''The interface for user defined data iterators.  Currently it is only
    supported by `DeviceQuantileDMatrix`.

    Parameters
    ----------

    rows : int
        Total number of rows combining all batches.
    cols : int
        Number of columns for each batch.
    '''
    def __init__(self):
        proxy_handle = ctypes.c_void_p()
        _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(proxy_handle)))
        self._handle = DeviceQuantileDMatrix(proxy_handle)
        self.exception = None

    @property
    def proxy(self):
        '''Handle of the DMatrix proxy.'''
        return self._handle

    def reset_wrapper(self, this):  # pylint: disable=unused-argument
        '''A wrapper for user defined `reset` function.'''
        self.reset()

    def next_wrapper(self, this):  # pylint: disable=unused-argument
        '''A wrapper for user defined `next` function.

        `this` is not used in Python.  ctypes can handle `self` of a Python
        member function automatically when converting it to a C function
        pointer.

        '''
        if self.exception is not None:
            return 0

        def data_handle(data, label=None, weight=None, base_margin=None,
                        group=None,
                        label_lower_bound=None, label_upper_bound=None):
            if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
                # pylint: disable=protected-access
                self.proxy._set_data_from_cuda_columnar(data)
            elif lazy_isinstance(data, 'cudf.core.series', 'Series'):
                # pylint: disable=protected-access
                self.proxy._set_data_from_cuda_columnar(data)
            elif lazy_isinstance(data, 'cupy.core.core', 'ndarray'):
                # pylint: disable=protected-access
                self.proxy._set_data_from_cuda_interface(data)
            else:
                raise TypeError(
                    'Value type is not supported for data iterator:' +
                    str(type(self._handle)), type(data))
            self.proxy.set_info(label=label, weight=weight,
                                base_margin=base_margin,
                                group=group,
                                label_lower_bound=label_lower_bound,
                                label_upper_bound=label_upper_bound)
        try:
            # Defer the exception in order to return 0 and stop the iteration.
            # An exception inside a ctypes callback function has no effect
            # except for printing to stderr (it doesn't stop the execution).
            ret = self.next(data_handle)  # pylint: disable=not-callable
        except Exception as e:            # pylint: disable=broad-except
            tb = sys.exc_info()[2]
            print('Got an exception in Python')
            self.exception = e.with_traceback(tb)
            return 0
        return ret

    def reset(self):
        '''Reset the data iterator.  Prototype for user defined function.'''
        raise NotImplementedError()

    def next(self, input_data):
        '''Set the next batch of data.

        Parameters
        ----------

        input_data: callable
            A function with the same data fields (e.g. `data`, `label`) as
            `xgboost.DMatrix`.

        Returns
        -------
        0 if there's no more batch, otherwise 1.

        '''
        raise NotImplementedError()

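`reset_wrapper` and `next_wrapper` are later handed to the C library as plain function pointers (see the `ctypes.CFUNCTYPE` calls in the data handler). A small standalone sketch of the mechanism described in the `next_wrapper` docstring: a bound method can be converted directly, so `self` travels with the Python callable and the C side only ever sees a `void*` argument. The `Counter` class here is illustrative only, not part of XGBoost.

import ctypes

# Prototype mirroring the `next` callback: returns int, takes a void*.
CALLBACK_TYPE = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)


class Counter:
    def __init__(self):
        self.calls = 0

    def next_wrapper(self, this):   # `this` mirrors the unused C handle
        self.calls += 1
        return 1 if self.calls < 3 else 0


counter = Counter()
# The bound method keeps a reference to `counter`, so no extra `self`
# plumbing is needed when the callback is invoked.  The callback object
# must stay alive for as long as the C side may call it.
next_callback = CALLBACK_TYPE(counter.next_wrapper)

# Calling through the ctypes function pointer behaves like a C call.
assert next_callback(None) == 1
assert next_callback(None) == 1
assert next_callback(None) == 0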
class DMatrix:  # pylint: disable=too-many-instance-attributes
    """Data Matrix used in XGBoost.

@@ -361,36 +454,65 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
            self.handle = None
            return

        handler = self.get_data_handler(data)
        handler = self._get_data_handler(data)
        can_handle_meta = False
        if handler is None:
            data = _convert_unknown_data(data, None)
            handler = self.get_data_handler(data)
            handler = self._get_data_handler(data)
        try:
            handler.handle_meta(label, weight, base_margin)
            can_handle_meta = True
        except NotImplementedError:
            can_handle_meta = False

        self.handle, feature_names, feature_types = handler.handle_input(
            data, feature_names, feature_types)
        assert self.handle, 'Failed to construct a DMatrix.'

        if label is not None:
            self.set_label(label)
        if weight is not None:
            self.set_weight(weight)
        if base_margin is not None:
            self.set_base_margin(base_margin)
        if not can_handle_meta:
            self.set_info(label, weight, base_margin)

        self.feature_names = feature_names
        self.feature_types = feature_types

    def get_data_handler(self, data, meta=None, meta_type=None):
    def _get_data_handler(self, data, meta=None, meta_type=None):
        '''Get data handler for this DMatrix class.'''
        from .data import get_dmatrix_data_handler
        handler = get_dmatrix_data_handler(
            data, self.missing, self.nthread, self.silent, meta, meta_type)
        return handler

    # pylint: disable=no-self-use
    def _get_meta_handler(self, data, meta, meta_type):
        from .data import get_dmatrix_meta_handler
        handler = get_dmatrix_meta_handler(
            data, meta, meta_type)
        return handler

    def __del__(self):
        if hasattr(self, "handle") and self.handle:
            _check_call(_LIB.XGDMatrixFree(self.handle))
            self.handle = None

    def set_info(self,
                 label=None, weight=None, base_margin=None,
                 group=None,
                 label_lower_bound=None,
                 label_upper_bound=None):
        '''Set meta info for DMatrix.'''
        if label is not None:
            self.set_label(label)
        if weight is not None:
            self.set_weight(weight)
        if base_margin is not None:
            self.set_base_margin(base_margin)
        if group is not None:
            self.set_group(group)
        if label_lower_bound is not None:
            self.set_float_info('label_lower_bound', label_lower_bound)
        if label_upper_bound is not None:
            self.set_float_info('label_upper_bound', label_upper_bound)

    def get_float_info(self, field):
        """Get float property from the DMatrix.

@@ -447,10 +569,8 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
        if isinstance(data, np.ndarray):
            self.set_float_info_npy2d(field, data)
            return
        handler = self.get_data_handler(data, field, np.float32)
        if handler is None:
            data = _convert_unknown_data(data, field, np.float32)
            handler = self.get_data_handler(data, field, np.float32)
        handler = self._get_data_handler(data, field, np.float32)
        assert handler
        data, _, _ = handler.transform(data)
        c_data = c_array(ctypes.c_float, data)
        _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,

@@ -470,7 +590,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
        data: numpy array
            The array of data to be set
        """
        data, _, _ = self.get_data_handler(
        data, _, _ = self._get_meta_handler(
            data, field, np.float32).transform(data)
        c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
        _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,

@@ -489,7 +609,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
        data: numpy array
            The array of data to be set
        """
        data, _, _ = self.get_data_handler(
        data, _, _ = self._get_data_handler(
            data, field, 'uint32').transform(data)
        _check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
                                              c_str(field),
@@ -803,11 +923,11 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes


class DeviceQuantileDMatrix(DMatrix):
    """Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do not
    use this for test/validation tasks as some information may be lost in quantisation. This
    DMatrix is primarily designed to save memory in training from device memory inputs by
    avoiding intermediate storage. Implementation does not currently consider weights in
    quantisation process (unlike DMatrix). Set max_bin to control the number of bins during
    """Device memory Data Matrix used in XGBoost for training with
    tree_method='gpu_hist'.  Do not use this for test/validation tasks as some
    information may be lost in quantisation.  This DMatrix is primarily designed
    to save memory in training from device memory inputs by avoiding
    intermediate storage.  Set max_bin to control the number of bins during
    quantisation.

    You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.

@@ -823,6 +943,9 @@ class DeviceQuantileDMatrix(DMatrix):
                 feature_types=None,
                 nthread=None, max_bin=256):
        self.max_bin = max_bin
        if isinstance(data, ctypes.c_void_p):
            self.handle = data
            return
        super().__init__(data, label=label, weight=weight,
                         base_margin=base_margin,
                         missing=missing,

@@ -831,11 +954,32 @@ class DeviceQuantileDMatrix(DMatrix):
                         feature_types=feature_types,
                         nthread=nthread)

    def get_data_handler(self, data, meta=None, meta_type=None):
    def _get_data_handler(self, data, meta=None, meta_type=None):
        from .data import get_device_quantile_dmatrix_data_handler
        return get_device_quantile_dmatrix_data_handler(
            data, self.max_bin, self.missing, self.nthread, self.silent)

    def _set_data_from_cuda_interface(self, data):
        '''Set data from CUDA array interface.'''
        interface = data.__cuda_array_interface__
        interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
        _check_call(
            _LIB.XGDeviceQuantileDMatrixSetDataCudaArrayInterface(
                self.handle,
                interface_str
            )
        )

    def _set_data_from_cuda_columnar(self, data):
        '''Set data from CUDA columnar format.'''
        interfaces_str = _cudf_array_interfaces(data)
        _check_call(
            _LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
                self.handle,
                interfaces_str
            )
        )


class Booster(object):
    # pylint: disable=too-many-public-methods
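`_set_data_from_cuda_interface` above serializes the standard `__cuda_array_interface__` dictionary to JSON before passing it to the C API. For orientation, a small sketch of what that dictionary typically contains for a cupy array; the exact keys depend on the cupy version and the data pointer differs on every run, so the commented values are illustrative only.

import json

import cupy

# A tiny device array; the interface below is what json.dumps() receives.
x = cupy.arange(6, dtype=cupy.float32).reshape(2, 3)
interface = x.__cuda_array_interface__

# Typical contents (illustrative):
# {
#     "shape": (2, 3),
#     "typestr": "<f4",
#     "data": (140231231881216, False),
#     "version": 2,
# }
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
print(interface_str.decode('utf-8'))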
@@ -1,4 +1,4 @@
# pylint: disable=too-many-arguments, no-self-use
# pylint: disable=too-many-arguments, no-self-use, too-many-instance-attributes
'''Data dispatching for DMatrix.'''
import ctypes
import abc

@@ -8,6 +8,7 @@ import warnings
import numpy as np

from .core import c_array, _LIB, _check_call, c_str, _cudf_array_interfaces
from .core import DataIter
from .compat import lazy_isinstance, STRING_TYPES, os_fspath, os_PathLike

c_bst_ulong = ctypes.c_uint64  # pylint: disable=invalid-name

@@ -23,6 +24,18 @@ class DataHandler(abc.ABC):
        self.meta = meta
        self.meta_type = meta_type

    def handle_meta(self, label=None, weight=None, base_margin=None,
                    group=None,
                    label_lower_bound=None,
                    label_upper_bound=None):
        '''Handle meta data when the DMatrix type cannot defer setting meta
        data until after construction.  An example is `DeviceQuantileDMatrix`,
        which requires the weight to be present before digesting the data.

        '''
        raise NotImplementedError()

    def _warn_unused_missing(self, data):
        if not (np.isnan(np.nan) or None):
            warnings.warn(

@@ -116,6 +129,14 @@ def get_dmatrix_data_handler(data, missing, nthread, silent,
    return handler(missing, nthread, silent, meta, meta_type)


def get_dmatrix_meta_handler(data, meta, meta_type):
    '''Get handler for meta instead of data.'''
    handler = __dmatrix_registry.get_handler(data)
    if handler is None:
        return None
    return handler(None, 0, True, meta, meta_type)


class FileHandler(DataHandler):
    '''Handler of path-like input.'''
    def handle_input(self, data, feature_names, feature_types):

@@ -511,6 +532,43 @@ __dmatrix_registry.register_handler_opaque(
    DLPackHandler)


class SingleBatchInternalIter(DataIter):
    '''An iterator for single batch data to help create the device DMatrix.
    Transforming the input directly to a histogram with the normal single
    batch data API cannot access the weight for sketching, so this iterator
    acts as a staging area for meta info.

    '''
    def __init__(self, data, label, weight, base_margin, group,
                 label_lower_bound, label_upper_bound):
        self.data = data
        self.label = label
        self.weight = weight
        self.base_margin = base_margin
        self.group = group
        self.label_lower_bound = label_lower_bound
        self.label_upper_bound = label_upper_bound
        self.it = 0             # pylint: disable=invalid-name
        super().__init__()

    def next(self, input_data):
        if self.it == 1:
            return 0
        self.it += 1
        input_data(data=self.data, label=self.label,
                   weight=self.weight, base_margin=self.base_margin,
                   group=self.group,
                   label_lower_bound=self.label_lower_bound,
                   label_upper_bound=self.label_upper_bound)
        return 1

    def reset(self):
        self.it = 0

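In other words, a `DeviceQuantileDMatrix` built from a single in-memory batch still goes through the iterator path, so a weight vector supplied at construction is staged by `SingleBatchInternalIter` and visible to the quantile sketch. A short usage sketch, assuming cupy and a CUDA device are available; the shapes are arbitrary.

import cupy
import xgboost

rng = cupy.random.RandomState(0)
X = rng.randn(1000, 10)
y = rng.randn(1000)
w = rng.uniform(size=1000)          # per-row weights used during sketching

# Internally the single batch is wrapped in SingleBatchInternalIter, so the
# weight is staged and available when the quantile sketch is built.
# max_bin controls the number of histogram bins used for quantisation.
dtrain = xgboost.DeviceQuantileDMatrix(X, label=y, weight=w, max_bin=256)
booster = xgboost.train({'tree_method': 'gpu_hist'}, dtrain,
                        num_boost_round=10)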
__device_quantile_dmatrix_registry = DMatrixDataManager()  # pylint: disable=invalid-name


class DeviceQuantileDMatrixDataHandler(DataHandler):  # pylint: disable=abstract-method
    '''Base class of data handler for `DeviceQuantileDMatrix`.'''
    def __init__(self, max_bin, missing, nthread, silent,

@@ -518,8 +576,53 @@ class DeviceQuantileDMatrixDataHandler(DataHandler):  # pylint: disable=abstract
        self.max_bin = max_bin
        super().__init__(missing, nthread, silent, meta, meta_type)

    def handle_meta(self, label=None, weight=None, base_margin=None,
                    group=None,
                    label_lower_bound=None,
                    label_upper_bound=None):
        self.label = label
        self.weight = weight
        self.base_margin = base_margin
        self.group = group
        self.label_lower_bound = label_lower_bound
        self.label_upper_bound = label_upper_bound

__device_quantile_dmatrix_registry = DMatrixDataManager()  # pylint: disable=invalid-name
    def handle_input(self, data, feature_names, feature_types):
        if not isinstance(data, DataIter):
            it = SingleBatchInternalIter(
                data, self.label, self.weight,
                self.base_margin, self.group,
                self.label_lower_bound, self.label_upper_bound)
        else:
            it = data
        reset_factory = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
        reset_callback = reset_factory(it.reset_wrapper)
        next_factory = ctypes.CFUNCTYPE(
            ctypes.c_int,
            ctypes.c_void_p,
        )
        next_callback = next_factory(it.next_wrapper)
        handle = ctypes.c_void_p()
        ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
            None,
            it.proxy.handle,
            reset_callback,
            next_callback,
            ctypes.c_float(self.missing),
            ctypes.c_int(self.nthread),
            ctypes.c_int(self.max_bin),
            ctypes.byref(handle)
        )
        if it.exception:
            raise it.exception
        # delay check_call to throw intermediate exception first
        _check_call(ret)
        return handle, feature_names, feature_types


__device_quantile_dmatrix_registry.register_handler_opaque(
    lambda x: isinstance(x, DataIter),
    DeviceQuantileDMatrixDataHandler)


def get_device_quantile_dmatrix_data_handler(

@@ -549,19 +652,7 @@ class DeviceQuantileCudaArrayInterfaceHandler(
                data, '__array__'):
            import cupy  # pylint: disable=import-error
            data = cupy.array(data, copy=False)

        interface = data.__cuda_array_interface__
        if 'mask' in interface:
            interface['mask'] = interface['mask'].__cuda_array_interface__
        interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')

        handle = ctypes.c_void_p()
        _check_call(
            _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
                interface_str,
                ctypes.c_float(self.missing), ctypes.c_int(self.nthread),
                ctypes.c_int(self.max_bin), ctypes.byref(handle)))
        return handle, feature_names, feature_types
        return super().handle_input(data, feature_names, feature_types)


__device_quantile_dmatrix_registry.register_handler(

@@ -582,14 +673,7 @@ class DeviceQuantileCudaColumnarHandler(DeviceQuantileDMatrixDataHandler,
        """Initialize Quantile Device DMatrix from columnar memory format."""
        data, feature_names, feature_types = self._maybe_cudf_dataframe(
            data, feature_names, feature_types)
        interfaces_str = _cudf_array_interfaces(data)
        handle = ctypes.c_void_p()
        _check_call(
            _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
                interfaces_str,
                ctypes.c_float(self.missing), ctypes.c_int(self.nthread),
                ctypes.c_int(self.max_bin), ctypes.byref(handle)))
        return handle, feature_names, feature_types
        return super().handle_input(data, feature_names, feature_types)


__device_quantile_dmatrix_registry.register_handler(
@@ -4,7 +4,6 @@
#include "xgboost/learner.h"
#include "c_api_error.h"
#include "../data/device_adapter.cuh"
#include "../data/device_dmatrix.h"

using namespace xgboost;  // NOLINT

@@ -31,28 +30,6 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
  API_END();
}

XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
                                                                   bst_float missing, int nthread, int max_bin,
                                                                   DMatrixHandle* out) {
  API_BEGIN();
  std::string json_str{c_json_strs};
  data::CudfAdapter adapter(json_str);
  *out =
      new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
  API_END();
}

XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterface(char const* c_json_strs,
                                                            bst_float missing, int nthread, int max_bin,
                                                            DMatrixHandle* out) {
  API_BEGIN();
  std::string json_str{c_json_strs};
  data::CupyAdapter adapter(json_str);
  *out =
      new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
  API_END();
}

// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
                                                      char const* c_json_strs,

@@ -201,7 +201,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {

// Returns maximum row length
template <typename AdapterBatchT>
size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
                    int device_idx, float missing) {
  IsValidFunctor is_valid(missing);
  // Count elements per row
@@ -1,58 +0,0 @@
/*!
 * Copyright 2020 by Contributors
 * \file device_dmatrix.cu
 * \brief Device-memory version of DMatrix.
 */

#include <thrust/execution_policy.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <memory>
#include <utility>
#include "../common/hist_util.h"
#include "adapter.h"
#include "device_adapter.cuh"
#include "ellpack_page.cuh"
#include "device_dmatrix.h"

namespace xgboost {
namespace data {
// Does not currently support metainfo as no on-device data source contains this
// Current implementation assumes a single batch. More batches can
// be supported in future. Does not currently support inferring row/column size
template <typename AdapterT>
DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin) {
  dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
  auto& batch = adapter->Value();
  // Work out how many valid entries we have in each row
  dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1, 0);
  common::Span<size_t> row_counts_span(row_counts.data().get(),
                                       row_counts.size());
  size_t row_stride =
      GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);

  dh::XGBCachingDeviceAllocator<char> alloc;
  info_.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
                                      row_counts.begin(), row_counts.end());
  info_.num_col_ = adapter->NumColumns();
  info_.num_row_ = adapter->NumRows();

  ellpack_page_.reset(new EllpackPage());
  *ellpack_page_->Impl() =
      EllpackPageImpl(adapter, missing, this->IsDense(), nthread, max_bin,
                      row_counts_span, row_stride);

  // Synchronise worker columns
  rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
}

#define DEVICE_DMARIX_SPECIALIZATION(__ADAPTER_T)                             \
  template DeviceDMatrix::DeviceDMatrix(__ADAPTER_T* adapter, float missing, \
                                        int nthread, int max_bin);

DEVICE_DMARIX_SPECIALIZATION(CudfAdapter);
DEVICE_DMARIX_SPECIALIZATION(CupyAdapter);
}  // namespace data
}  // namespace xgboost
@@ -1,64 +0,0 @@
/*!
 * Copyright 2020 by Contributors
 * \file device_dmatrix.h
 * \brief Device-memory version of DMatrix.
 */
#ifndef XGBOOST_DATA_DEVICE_DMATRIX_H_
#define XGBOOST_DATA_DEVICE_DMATRIX_H_

#include <xgboost/base.h>
#include <xgboost/data.h>

#include <memory>

#include "adapter.h"
#include "simple_batch_iterator.h"
#include "simple_dmatrix.h"

namespace xgboost {
namespace data {

class DeviceDMatrix : public DMatrix {
 public:
  template <typename AdapterT>
  explicit DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin);

  MetaInfo& Info() override { return info_; }

  const MetaInfo& Info() const override { return info_; }

  bool SingleColBlock() const override { return true; }

  bool EllpackExists() const override { return true; }
  bool SparsePageExists() const override { return false; }
  DMatrix *Slice(common::Span<int32_t const> ridxs) override {
    LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix.";
    return nullptr;
  }

 private:
  BatchSet<SparsePage> GetRowBatches() override {
    LOG(FATAL) << "Not implemented.";
    return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
  }
  BatchSet<CSCPage> GetColumnBatches() override {
    LOG(FATAL) << "Not implemented.";
    return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
  }
  BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
    LOG(FATAL) << "Not implemented.";
    return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
  }
  BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override {
    auto begin_iter = BatchIterator<EllpackPage>(
        new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
    return BatchSet<EllpackPage>(begin_iter);
  }

  MetaInfo info_;
  // source data pointer.
  std::unique_ptr<EllpackPage> ellpack_page_;
};
}  // namespace data
}  // namespace xgboost
#endif  // XGBOOST_DATA_DEVICE_DMATRIX_H_
@@ -93,6 +93,11 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
    batches++;
  }

  if (device < 0) {  // error or empty
    this->page_.reset(new EllpackPage);
    return;
  }

  common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
  for (auto const& sketch : sketch_containers) {
    final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());

@@ -108,14 +113,23 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
  this->info_.num_row_ = accumulated_rows;
  this->info_.num_nonzero_ = nnz;

  // Construct the final ellpack page.
  auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows]() {
    if (!page_) {
      // Should be put inside the while loop to protect against empty batch. In
      // that case device id is invalid.
      page_.reset(new EllpackPage);
      *(page_->Impl()) = EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(),
                                         row_stride, accumulated_rows);
      *(page_->Impl()) =
          EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(), row_stride,
                          accumulated_rows);
    }
  };

  // Construct the final ellpack page.
  size_t offset = 0;
  iter.Reset();
  size_t n_batches_for_verification = 0;
  while (iter.Next()) {
    init_page();
    auto device = proxy->DeviceIdx();
    dh::safe_cuda(cudaSetDevice(device));
    auto rows = num_rows();

@@ -138,7 +152,10 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
    if (batches != 1) {
      this->info_.Extend(std::move(proxy->Info()), false);
    }
    n_batches_for_verification++;
  }
  CHECK_EQ(batches, n_batches_for_verification)
      << "Different number of batches returned between 2 iterations";

  if (batches == 1) {
    this->info_ = std::move(proxy->Info());

@@ -1,149 +0,0 @@

// Copyright (c) 2019 by Contributors
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include "../../../src/data/adapter.h"
#include "../../../src/data/ellpack_page.cuh"
#include "../../../src/data/device_dmatrix.h"
#include "../helpers.h"
#include <thrust/device_vector.h>
#include "../../../src/data/device_adapter.cuh"
#include "../../../src/gbm/gbtree_model.h"
#include "../common/test_hist_util.h"
#include "../../../src/common/compressed_iterator.h"
#include "../../../src/common/math.h"
#include "test_array_interface.h"
using namespace xgboost;  // NOLINT

TEST(DeviceDMatrix, RowMajor) {
  int num_rows = 1000;
  int num_columns = 50;
  auto x = common::GenerateRandom(num_rows, num_columns);
  auto x_device = thrust::device_vector<float>(x);
  auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);

  data::DeviceDMatrix dmat(&adapter,
                           std::numeric_limits<float>::quiet_NaN(), 1, 256);

  auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
  auto impl = batch.Impl();
  common::CompressedIterator<uint32_t> iterator(
      impl->gidx_buffer.HostVector().data(), impl->NumSymbols());
  for (auto i = 0ull; i < x.size(); i++) {
    int column_idx = i % num_columns;
    EXPECT_EQ(impl->Cuts().SearchBin(x[i], column_idx), iterator[i]);
  }
  EXPECT_EQ(dmat.Info().num_col_, num_columns);
  EXPECT_EQ(dmat.Info().num_row_, num_rows);
  EXPECT_EQ(dmat.Info().num_nonzero_, num_rows * num_columns);
}

TEST(DeviceDMatrix, RowMajorMissing) {
  const float kMissing = std::numeric_limits<float>::quiet_NaN();
  int num_rows = 10;
  int num_columns = 2;
  auto x = common::GenerateRandom(num_rows, num_columns);
  x[1] = kMissing;
  x[5] = kMissing;
  x[6] = kMissing;
  auto x_device = thrust::device_vector<float>(x);
  auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);

  data::DeviceDMatrix dmat(&adapter, kMissing, 1, 256);

  auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
  auto impl = batch.Impl();
  common::CompressedIterator<uint32_t> iterator(
      impl->gidx_buffer.HostVector().data(), impl->NumSymbols());
  EXPECT_EQ(iterator[1], impl->GetDeviceAccessor(0).NullValue());
  EXPECT_EQ(iterator[5], impl->GetDeviceAccessor(0).NullValue());
  // null values get placed after valid values in a row
  EXPECT_EQ(iterator[7], impl->GetDeviceAccessor(0).NullValue());
  EXPECT_EQ(dmat.Info().num_col_, num_columns);
  EXPECT_EQ(dmat.Info().num_row_, num_rows);
  EXPECT_EQ(dmat.Info().num_nonzero_, num_rows * num_columns - 3);
}

TEST(DeviceDMatrix, ColumnMajor) {
  constexpr size_t kRows{100};
  std::vector<Json> columns;
  thrust::device_vector<double> d_data_0(kRows);
  thrust::device_vector<uint32_t> d_data_1(kRows);

  columns.emplace_back(GenerateDenseColumn<double>("<f8", kRows, &d_data_0));
  columns.emplace_back(GenerateDenseColumn<uint32_t>("<u4", kRows, &d_data_1));

  Json column_arr{columns};

  std::string str;
  Json::Dump(column_arr, &str);

  data::CudfAdapter adapter(str);
  data::DeviceDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
                           -1, 256);
  auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
  auto impl = batch.Impl();
  common::CompressedIterator<uint32_t> iterator(
      impl->gidx_buffer.HostVector().data(), impl->NumSymbols());

  for (auto i = 0ull; i < kRows; i++) {
    for (auto j = 0ull; j < columns.size(); j++) {
      if (j == 0) {
        EXPECT_EQ(iterator[i * 2 + j], impl->Cuts().SearchBin(d_data_0[i], j));
      } else {
        EXPECT_EQ(iterator[i * 2 + j], impl->Cuts().SearchBin(d_data_1[i], j));
      }
    }
  }
  EXPECT_EQ(dmat.Info().num_col_, 2);
  EXPECT_EQ(dmat.Info().num_row_, kRows);
  EXPECT_EQ(dmat.Info().num_nonzero_, kRows * 2);
}

// Test equivalence with simple DMatrix
TEST(DeviceDMatrix, Equivalent) {
  int bin_sizes[] = {2, 16, 256, 512};
  int sizes[] = {100, 1000, 1500};
  int num_columns = 5;
  for (auto num_rows : sizes) {
    auto x = common::GenerateRandom(num_rows, num_columns);
    for (auto num_bins : bin_sizes) {
      auto dmat = common::GetDMatrixFromData(x, num_rows, num_columns);
      auto x_device = thrust::device_vector<float>(x);
      auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);
      data::DeviceDMatrix device_dmat(
          &adapter, std::numeric_limits<float>::quiet_NaN(), 1, num_bins);

      const auto &batch = *dmat->GetBatches<EllpackPage>({0, num_bins}).begin();
      const auto &device_dmat_batch =
          *device_dmat.GetBatches<EllpackPage>({0, num_bins}).begin();

      ASSERT_EQ(batch.Impl()->Cuts().Values(), device_dmat_batch.Impl()->Cuts().Values());
      ASSERT_EQ(batch.Impl()->gidx_buffer.HostVector(),
                device_dmat_batch.Impl()->gidx_buffer.HostVector());
    }
  }
}

TEST(DeviceDMatrix, IsDense) {
  int num_bins = 16;
  auto test = [num_bins] (float sparsity) {
    HostDeviceVector<float> data;
    std::string interface_str = RandomDataGenerator{10, 10, sparsity}
        .Device(0).GenerateArrayInterface(&data);
    data::CupyAdapter x{interface_str};
    std::unique_ptr<data::DeviceDMatrix> device_dmat{ new data::DeviceDMatrix(
        &x, std::numeric_limits<float>::quiet_NaN(), 1, num_bins) };
    if (sparsity == 0.0) {
      ASSERT_TRUE(device_dmat->IsDense()) << sparsity;
    } else {
      ASSERT_FALSE(device_dmat->IsDense());
    }
  };
  test(0.0);
  test(0.1);
}
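The `CHECK_EQ(batches, n_batches_for_verification)` added in `IterativeDeviceDMatrix::Initialize` is the C++ side of the `reset` requirement stated in the demo: the iterator is walked once for sketching and once more to fill the ellpack pages, and both passes must yield the same number of batches. A hedged, XGBoost-free illustration of why a plain generator cannot satisfy that contract while a resettable iterator can; the classes and names here are illustrative only.

def one_shot(batches):
    # A plain generator is exhausted after the first pass; a second walk
    # would see zero batches and trip the batch-count verification.
    for b in batches:
        yield b


class TwoPass:
    '''Replayable iteration: `reset` rewinds to the first batch.'''
    def __init__(self, batches):
        self._batches = batches
        self._it = 0

    def reset(self):
        self._it = 0

    def next(self):
        if self._it == len(self._batches):
            return None
        batch = self._batches[self._it]
        self._it += 1
        return batch


batches = ['b0', 'b1', 'b2']
gen = one_shot(batches)
assert len(list(gen)) == 3
assert len(list(gen)) == 0          # second pass over the generator is empty

tp = TwoPass(batches)
first = [tp.next() for _ in range(3)]
tp.reset()
second = [tp.next() for _ in range(3)]
assert first == second              # identical batches on both passes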
@@ -54,6 +54,7 @@ void TestTrainingPrediction(size_t rows, size_t bins,
  learner->SetParam("objective", "multi:softprob");
  learner->SetParam("num_feature", std::to_string(kCols));
  learner->SetParam("num_class", std::to_string(kClasses));
  learner->SetParam("max_bin", std::to_string(bins));
  learner->Configure();

  for (size_t i = 0; i < kIters; ++i) {
@@ -170,3 +170,83 @@ Arrow specification.'''
    @pytest.mark.skipif(**tm.no_cudf())
    def test_cudf_metainfo_device_dmatrix(self):
        _test_cudf_metainfo(xgb.DeviceQuantileDMatrix)


class IterForDMatrixTest(xgb.core.DataIter):
    '''A data iterator for XGBoost DMatrix.

    `reset` and `next` are required for any data iterator; the other
    functions here are utilities for demonstration purposes.

    '''
    ROWS_PER_BATCH = 100        # data is split by rows
    BATCHES = 16

    def __init__(self):
        '''Generate some random data for demonstration.

        Actual data can be anything that is currently supported by XGBoost.
        '''
        import cudf
        self.rows = self.ROWS_PER_BATCH
        rng = np.random.RandomState(1994)
        self._data = [
            cudf.DataFrame(
                {'a': rng.randn(self.ROWS_PER_BATCH),
                 'b': rng.randn(self.ROWS_PER_BATCH)})] * self.BATCHES
        self._labels = [rng.randn(self.rows)] * self.BATCHES

        self.it = 0             # set iterator to 0
        super().__init__()

    def as_array(self):
        import cudf
        return cudf.concat(self._data)

    def as_array_labels(self):
        return np.concatenate(self._labels)

    def data(self):
        '''Utility function for obtaining current batch of data.'''
        return self._data[self.it]

    def labels(self):
        '''Utility function for obtaining current batch of label.'''
        return self._labels[self.it]

    def reset(self):
        '''Reset the iterator.'''
        self.it = 0

    def next(self, input_data):
        '''Yield the next batch of data.'''
        if self.it == len(self._data):
            # Return 0 when there's no more batch.
            return 0
        input_data(data=self.data(), label=self.labels())
        self.it += 1
        return 1


@pytest.mark.skipif(**tm.no_cudf())
def test_from_cudf_iter():
    rounds = 100
    it = IterForDMatrixTest()

    # Use the iterator.
    m_it = xgb.DeviceQuantileDMatrix(it)
    reg_with_it = xgb.train({'tree_method': 'gpu_hist'}, m_it,
                            num_boost_round=rounds)
    predict_with_it = reg_with_it.predict(m_it)

    # Without using the iterator.
    m = xgb.DMatrix(it.as_array(), it.as_array_labels())

    assert m_it.num_col() == m.num_col()
    assert m_it.num_row() == m.num_row()

    reg = xgb.train({'tree_method': 'gpu_hist'}, m,
                    num_boost_round=rounds)
    predict = reg.predict(m)

    np.testing.assert_allclose(predict_with_it, predict)
tests/python-gpu/test_gpu_demos.py (new file, 11 lines)
@@ -0,0 +1,11 @@
import os
import subprocess
import sys
sys.path.append("tests/python")
import test_demos as td  # noqa


def test_data_iterator():
    script = os.path.join(td.PYTHON_DEMO_DIR, 'data_iterator.py')
    cmd = ['python', script]
    subprocess.check_call(cmd)