diff --git a/demo/guide-python/data_iterator.py b/demo/guide-python/data_iterator.py new file mode 100644 index 000000000..4f4b08c0f --- /dev/null +++ b/demo/guide-python/data_iterator.py @@ -0,0 +1,109 @@ +'''A demo for defining a data iterator. + +The demo defines a customized iterator for passing batches of data into +`xgboost.DeviceQuantileDMatrix` and uses this `DeviceQuantileDMatrix` for +training. The feature is primarily designed to reduce the required GPU +memory for training in a distributed environment. + +After going through the demo, one might ask why we don't use a more native +Python iterator? That's because XGBoost requires a `reset` function, while +using `itertools.tee` might incur significant memory usage according to: + + https://docs.python.org/3/library/itertools.html#itertools.tee. + +''' + +import xgboost +import cupy +import numpy + +COLS = 64 +ROWS_PER_BATCH = 1000 # data is split by rows +BATCHES = 32 + + +class IterForDMatrixDemo(xgboost.core.DataIter): + '''A data iterator for XGBoost DMatrix. + + `reset` and `next` are required for any data iterator, other functions here + are utilities for demonstration purposes. + + ''' + def __init__(self): + '''Generate some random data for demonstration. + + Actual data can be anything that is currently supported by XGBoost. 
+ ''' + self.rows = ROWS_PER_BATCH + self.cols = COLS + rng = cupy.random.RandomState(1994) + self._data = [rng.randn(self.rows, self.cols)] * BATCHES + self._labels = [rng.randn(self.rows)] * BATCHES + self._weights = [rng.randn(self.rows)] * BATCHES + + self.it = 0 # set iterator to 0 + super().__init__() + + def as_array(self): + return cupy.concatenate(self._data) + + def as_array_labels(self): + return cupy.concatenate(self._labels) + + def as_array_weights(self): + return cupy.concatenate(self._weights) + + def data(self): + '''Utility function for obtaining current batch of data.''' + return self._data[self.it] + + def labels(self): + '''Utility function for obtaining current batch of label.''' + return self._labels[self.it] + + def weights(self): + return self._weights[self.it] + + def reset(self): + '''Reset the iterator''' + self.it = 0 + + def next(self, input_data): + '''Yield next batch of data.''' + if self.it == len(self._data): + # Return 0 when there's no more batch. + return 0 + input_data(data=self.data(), label=self.labels(), + weight=self.weights()) + self.it += 1 + return 1 + + +def main(): + rounds = 100 + it = IterForDMatrixDemo() + + # Use iterator, must be `DeviceQuantileDMatrix` + m_with_it = xgboost.DeviceQuantileDMatrix(it) + + # Use regular DMatrix. 
+ m = xgboost.DMatrix(it.as_array(), it.as_array_labels(), + weight=it.as_array_weights()) + + assert m_with_it.num_col() == m.num_col() + assert m_with_it.num_row() == m.num_row() + + reg_with_it = xgboost.train({'tree_method': 'gpu_hist'}, m_with_it, + num_boost_round=rounds) + predict_with_it = reg_with_it.predict(m_with_it) + + reg = xgboost.train({'tree_method': 'gpu_hist'}, m, + num_boost_round=rounds) + predict = reg.predict(m) + + numpy.testing.assert_allclose(predict_with_it, predict, + rtol=1e6) + + +if __name__ == '__main__': + main() diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 687e47b2c..4a1208f66 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -300,6 +300,99 @@ def _cudf_array_interfaces(df): return interfaces_str +class DataIter: + '''The interface for user defined data iterator. Currently is only + supported by Device DMatrix. + + Parameters + ---------- + + rows : int + Total number of rows combining all batches. + cols : int + Number of columns for each batch. + ''' + def __init__(self): + proxy_handle = ctypes.c_void_p() + _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(proxy_handle))) + self._handle = DeviceQuantileDMatrix(proxy_handle) + self.exception = None + + @property + def proxy(self): + '''Handler of DMatrix proxy.''' + return self._handle + + def reset_wrapper(self, this): # pylint: disable=unused-argument + '''A wrapper for user defined `reset` function.''' + self.reset() + + def next_wrapper(self, this): # pylint: disable=unused-argument + '''A wrapper for user defined `next` function. + + `this` is not used in Python. ctypes can handle `self` of a Python + member function automatically when converting a it to c function + pointer. 
+ + ''' + if self.exception is not None: + return 0 + + def data_handle(data, label=None, weight=None, base_margin=None, + group=None, + label_lower_bound=None, label_upper_bound=None): + if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'): + # pylint: disable=protected-access + self.proxy._set_data_from_cuda_columnar(data) + elif lazy_isinstance(data, 'cudf.core.series', 'Series'): + # pylint: disable=protected-access + self.proxy._set_data_from_cuda_columnar(data) + elif lazy_isinstance(data, 'cupy.core.core', 'ndarray'): + # pylint: disable=protected-access + self.proxy._set_data_from_cuda_interface(data) + else: + raise TypeError( + 'Value type is not supported for data iterator:' + + str(type(self._handle)), type(data)) + self.proxy.set_info(label=label, weight=weight, + base_margin=base_margin, + group=group, + label_lower_bound=label_lower_bound, + label_upper_bound=label_upper_bound) + try: + # Deffer the exception in order to return 0 and stop the iteration. + # Exception inside a ctype callback function has no effect except + # for printing to stderr (doesn't stop the execution). + ret = self.next(data_handle) # pylint: disable=not-callable + except Exception as e: # pylint: disable=broad-except + tb = sys.exc_info()[2] + print('Got an exception in Python') + self.exception = e.with_traceback(tb) + return 0 + return ret + + def reset(self): + '''Reset the data iterator. Prototype for user defined function.''' + raise NotImplementedError() + + def next(self, input_data): + '''Set the next batch of data. + + Parameters + ---------- + + data_handle: callable + A function with same data fields like `data`, `label` with + `xgboost.DMatrix`. + + Returns + ------- + 0 if there's no more batch, otherwise 1. + + ''' + raise NotImplementedError() + + class DMatrix: # pylint: disable=too-many-instance-attributes """Data Matrix used in XGBoost. 
@@ -361,36 +454,65 @@ class DMatrix: # pylint: disable=too-many-instance-attributes self.handle = None return - handler = self.get_data_handler(data) + handler = self._get_data_handler(data) + can_handle_meta = False if handler is None: data = _convert_unknown_data(data, None) - handler = self.get_data_handler(data) + handler = self._get_data_handler(data) + try: + handler.handle_meta(label, weight, base_margin) + can_handle_meta = True + except NotImplementedError: + can_handle_meta = False + self.handle, feature_names, feature_types = handler.handle_input( data, feature_names, feature_types) assert self.handle, 'Failed to construct a DMatrix.' - if label is not None: - self.set_label(label) - if weight is not None: - self.set_weight(weight) - if base_margin is not None: - self.set_base_margin(base_margin) + if not can_handle_meta: + self.set_info(label, weight, base_margin) self.feature_names = feature_names self.feature_types = feature_types - def get_data_handler(self, data, meta=None, meta_type=None): + def _get_data_handler(self, data, meta=None, meta_type=None): '''Get data handler for this DMatrix class.''' from .data import get_dmatrix_data_handler handler = get_dmatrix_data_handler( data, self.missing, self.nthread, self.silent, meta, meta_type) return handler + # pylint: disable=no-self-use + def _get_meta_handler(self, data, meta, meta_type): + from .data import get_dmatrix_meta_handler + handler = get_dmatrix_meta_handler( + data, meta, meta_type) + return handler + def __del__(self): if hasattr(self, "handle") and self.handle: _check_call(_LIB.XGDMatrixFree(self.handle)) self.handle = None + def set_info(self, + label=None, weight=None, base_margin=None, + group=None, + label_lower_bound=None, + label_upper_bound=None): + '''Set meta info for DMatrix.''' + if label is not None: + self.set_label(label) + if weight is not None: + self.set_weight(weight) + if base_margin is not None: + self.set_base_margin(base_margin) + if group is not None: + 
self.set_group(group) + if label_lower_bound is not None: + self.set_float_info('label_lower_bound', label_lower_bound) + if label_upper_bound is not None: + self.set_float_info('label_upper_bound', label_upper_bound) + def get_float_info(self, field): """Get float property from the DMatrix. @@ -447,10 +569,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes if isinstance(data, np.ndarray): self.set_float_info_npy2d(field, data) return - handler = self.get_data_handler(data, field, np.float32) - if handler is None: - data = _convert_unknown_data(data, field, np.float32) - handler = self.get_data_handler(data, field, np.float32) + handler = self._get_data_handler(data, field, np.float32) + assert handler data, _, _ = handler.transform(data) c_data = c_array(ctypes.c_float, data) _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle, @@ -470,7 +590,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes data: numpy array The array of data to be set """ - data, _, _ = self.get_data_handler( + data, _, _ = self._get_meta_handler( data, field, np.float32).transform(data) c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) _check_call(_LIB.XGDMatrixSetFloatInfo(self.handle, @@ -489,7 +609,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes data: numpy array The array of data to be set """ - data, _, _ = self.get_data_handler( + data, _, _ = self._get_data_handler( data, field, 'uint32').transform(data) _check_call(_LIB.XGDMatrixSetUIntInfo(self.handle, c_str(field), @@ -803,11 +923,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes class DeviceQuantileDMatrix(DMatrix): - """Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do not - use this for test/validation tasks as some information may be lost in quantisation. This - DMatrix is primarily designed to save memory in training from device memory inputs by - avoiding intermediate storage. 
Implementation does not currently consider weights in - quantisation process(unlike DMatrix). Set max_bin to control the number of bins during + """Device memory Data Matrix used in XGBoost for training with + tree_method='gpu_hist'. Do not use this for test/validation tasks as some + information may be lost in quantisation. This DMatrix is primarily designed + to save memory in training from device memory inputs by avoiding + intermediate storage. Set max_bin to control the number of bins during quantisation. You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack. @@ -823,6 +943,9 @@ class DeviceQuantileDMatrix(DMatrix): feature_types=None, nthread=None, max_bin=256): self.max_bin = max_bin + if isinstance(data, ctypes.c_void_p): + self.handle = data + return super().__init__(data, label=label, weight=weight, base_margin=base_margin, missing=missing, @@ -831,11 +954,32 @@ class DeviceQuantileDMatrix(DMatrix): feature_types=feature_types, nthread=nthread) - def get_data_handler(self, data, meta=None, meta_type=None): + def _get_data_handler(self, data, meta=None, meta_type=None): from .data import get_device_quantile_dmatrix_data_handler return get_device_quantile_dmatrix_data_handler( data, self.max_bin, self.missing, self.nthread, self.silent) + def _set_data_from_cuda_interface(self, data): + '''Set data from CUDA array interface.''' + interface = data.__cuda_array_interface__ + interface_str = bytes(json.dumps(interface, indent=2), 'utf-8') + _check_call( + _LIB.XGDeviceQuantileDMatrixSetDataCudaArrayInterface( + self.handle, + interface_str + ) + ) + + def _set_data_from_cuda_columnar(self, data): + '''Set data from CUDA columnar format.1''' + interfaces_str = _cudf_array_interfaces(data) + _check_call( + _LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar( + self.handle, + interfaces_str + ) + ) + class Booster(object): # pylint: disable=too-many-public-methods diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 
170ce81c0..955802ec8 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -1,4 +1,4 @@ -# pylint: disable=too-many-arguments, no-self-use +# pylint: disable=too-many-arguments, no-self-use, too-many-instance-attributes '''Data dispatching for DMatrix.''' import ctypes import abc @@ -8,6 +8,7 @@ import warnings import numpy as np from .core import c_array, _LIB, _check_call, c_str, _cudf_array_interfaces +from .core import DataIter from .compat import lazy_isinstance, STRING_TYPES, os_fspath, os_PathLike c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name @@ -23,6 +24,18 @@ class DataHandler(abc.ABC): self.meta = meta self.meta_type = meta_type + def handle_meta(self, label=None, weight=None, base_margin=None, + group=None, + label_lower_bound=None, + label_upper_bound=None): + '''Handle meta data when the DMatrix type can not defer setting meta + data after construction. Example is `DeviceQuantileDMatrix` + which requires weight to be presented before digesting + data. + + ''' + raise NotImplementedError() + def _warn_unused_missing(self, data): if not (np.isnan(np.nan) or None): warnings.warn( @@ -116,6 +129,14 @@ def get_dmatrix_data_handler(data, missing, nthread, silent, return handler(missing, nthread, silent, meta, meta_type) +def get_dmatrix_meta_handler(data, meta, meta_type): + '''Get handler for meta instead of data.''' + handler = __dmatrix_registry.get_handler(data) + if handler is None: + return None + return handler(None, 0, True, meta, meta_type) + + class FileHandler(DataHandler): '''Handler of path like input.''' def handle_input(self, data, feature_names, feature_types): @@ -511,6 +532,43 @@ __dmatrix_registry.register_handler_opaque( DLPackHandler) +class SingleBatchInternalIter(DataIter): + '''An iterator for single batch data to help creating device DMatrix. + Transforming input directly to histogram with normal single batch data API + can not access weight for sketching. 
So this iterator acts as a staging + area for meta info. + + ''' + def __init__(self, data, label, weight, base_margin, group, + label_lower_bound, label_upper_bound): + self.data = data + self.label = label + self.weight = weight + self.base_margin = base_margin + self.group = group + self.label_lower_bound = label_lower_bound + self.label_upper_bound = label_upper_bound + self.it = 0 # pylint: disable=invalid-name + super().__init__() + + def next(self, input_data): + if self.it == 1: + return 0 + self.it += 1 + input_data(data=self.data, label=self.label, + weight=self.weight, base_margin=self.base_margin, + group=self.group, + label_lower_bound=self.label_lower_bound, + label_upper_bound=self.label_upper_bound) + return 1 + + def reset(self): + self.it = 0 + + +__device_quantile_dmatrix_registry = DMatrixDataManager() # pylint: disable=invalid-name + + class DeviceQuantileDMatrixDataHandler(DataHandler): # pylint: disable=abstract-method '''Base class of data handler for `DeviceQuantileDMatrix`.''' def __init__(self, max_bin, missing, nthread, silent, @@ -518,8 +576,53 @@ class DeviceQuantileDMatrixDataHandler(DataHandler): # pylint: disable=abstract self.max_bin = max_bin super().__init__(missing, nthread, silent, meta, meta_type) + def handle_meta(self, label=None, weight=None, base_margin=None, + group=None, + label_lower_bound=None, + label_upper_bound=None): + self.label = label + self.weight = weight + self.base_margin = base_margin + self.group = group + self.label_lower_bound = label_lower_bound + self.label_upper_bound = label_upper_bound -__device_quantile_dmatrix_registry = DMatrixDataManager() # pylint: disable=invalid-name + def handle_input(self, data, feature_names, feature_types): + if not isinstance(data, DataIter): + it = SingleBatchInternalIter( + data, self.label, self.weight, + self.base_margin, self.group, + self.label_lower_bound, self.label_upper_bound) + else: + it = data + reset_factory = ctypes.CFUNCTYPE(None, ctypes.c_void_p) + 
reset_callback = reset_factory(it.reset_wrapper) + next_factory = ctypes.CFUNCTYPE( + ctypes.c_int, + ctypes.c_void_p, + ) + next_callback = next_factory(it.next_wrapper) + handle = ctypes.c_void_p() + ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback( + None, + it.proxy.handle, + reset_callback, + next_callback, + ctypes.c_float(self.missing), + ctypes.c_int(self.nthread), + ctypes.c_int(self.max_bin), + ctypes.byref(handle) + ) + if it.exception: + raise it.exception + # delay check_call to throw intermediate exception first + _check_call(ret) + return handle, feature_names, feature_types + + +__device_quantile_dmatrix_registry.register_handler_opaque( + lambda x: isinstance(x, DataIter), + DeviceQuantileDMatrixDataHandler) def get_device_quantile_dmatrix_data_handler( @@ -549,19 +652,7 @@ class DeviceQuantileCudaArrayInterfaceHandler( data, '__array__'): import cupy # pylint: disable=import-error data = cupy.array(data, copy=False) - - interface = data.__cuda_array_interface__ - if 'mask' in interface: - interface['mask'] = interface['mask'].__cuda_array_interface__ - interface_str = bytes(json.dumps(interface, indent=2), 'utf-8') - - handle = ctypes.c_void_p() - _check_call( - _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface( - interface_str, - ctypes.c_float(self.missing), ctypes.c_int(self.nthread), - ctypes.c_int(self.max_bin), ctypes.byref(handle))) - return handle, feature_names, feature_types + return super().handle_input(data, feature_names, feature_types) __device_quantile_dmatrix_registry.register_handler( @@ -582,14 +673,7 @@ class DeviceQuantileCudaColumnarHandler(DeviceQuantileDMatrixDataHandler, """Initialize Quantile Device DMatrix from columnar memory format.""" data, feature_names, feature_types = self._maybe_cudf_dataframe( data, feature_names, feature_types) - interfaces_str = _cudf_array_interfaces(data) - handle = ctypes.c_void_p() - _check_call( - _LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns( - interfaces_str, - 
ctypes.c_float(self.missing), ctypes.c_int(self.nthread), - ctypes.c_int(self.max_bin), ctypes.byref(handle))) - return handle, feature_names, feature_types + return super().handle_input(data, feature_names, feature_types) __device_quantile_dmatrix_registry.register_handler( diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index f1a486d8c..5af04894d 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -4,7 +4,6 @@ #include "xgboost/learner.h" #include "c_api_error.h" #include "../data/device_adapter.cuh" -#include "../data/device_dmatrix.h" using namespace xgboost; // NOLINT @@ -31,28 +30,6 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs, API_END(); } -XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs, - bst_float missing, int nthread, int max_bin, - DMatrixHandle* out) { - API_BEGIN(); - std::string json_str{c_json_strs}; - data::CudfAdapter adapter(json_str); - *out = - new std::shared_ptr(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin)); - API_END(); -} - -XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterface(char const* c_json_strs, - bst_float missing, int nthread, int max_bin, - DMatrixHandle* out) { - API_BEGIN(); - std::string json_str{c_json_strs}; - data::CupyAdapter adapter(json_str); - *out = - new std::shared_ptr(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin)); - API_END(); -} - // A hidden API as cache id is not being supported yet. 
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle, char const* c_json_strs, diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 5f6a3b6cc..10ae6ba3d 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -201,7 +201,7 @@ class CupyAdapter : public detail::SingleBatchDataIter { // Returns maximum row length template -size_t GetRowCounts(const AdapterBatchT& batch, common::Span offset, +size_t GetRowCounts(const AdapterBatchT batch, common::Span offset, int device_idx, float missing) { IsValidFunctor is_valid(missing); // Count elements per row diff --git a/src/data/device_dmatrix.cu b/src/data/device_dmatrix.cu deleted file mode 100644 index d11d01b16..000000000 --- a/src/data/device_dmatrix.cu +++ /dev/null @@ -1,58 +0,0 @@ -/*! - * Copyright 2020 by Contributors - * \file device_dmatrix.cu - * \brief Device-memory version of DMatrix. - */ - -#include -#include -#include -#include -#include -#include -#include -#include "../common/hist_util.h" -#include "adapter.h" -#include "device_adapter.cuh" -#include "ellpack_page.cuh" -#include "device_dmatrix.h" - -namespace xgboost { -namespace data { -// Does not currently support metainfo as no on-device data source contains this -// Current implementation assumes a single batch. More batches can -// be supported in future. 
Does not currently support inferring row/column size -template -DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin) { - dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx())); - auto& batch = adapter->Value(); - // Work out how many valid entries we have in each row - dh::caching_device_vector row_counts(adapter->NumRows() + 1, 0); - common::Span row_counts_span(row_counts.data().get(), - row_counts.size()); - size_t row_stride = - GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing); - - dh::XGBCachingDeviceAllocator alloc; - info_.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc), - row_counts.begin(), row_counts.end()); - info_.num_col_ = adapter->NumColumns(); - info_.num_row_ = adapter->NumRows(); - - ellpack_page_.reset(new EllpackPage()); - *ellpack_page_->Impl() = - EllpackPageImpl(adapter, missing, this->IsDense(), nthread, max_bin, - row_counts_span, row_stride); - - // Synchronise worker columns - rabit::Allreduce(&info_.num_col_, 1); -} - -#define DEVICE_DMARIX_SPECIALIZATION(__ADAPTER_T) \ - template DeviceDMatrix::DeviceDMatrix(__ADAPTER_T* adapter, float missing, \ - int nthread, int max_bin); - -DEVICE_DMARIX_SPECIALIZATION(CudfAdapter); -DEVICE_DMARIX_SPECIALIZATION(CupyAdapter); -} // namespace data -} // namespace xgboost diff --git a/src/data/device_dmatrix.h b/src/data/device_dmatrix.h deleted file mode 100644 index 781461baa..000000000 --- a/src/data/device_dmatrix.h +++ /dev/null @@ -1,64 +0,0 @@ -/*! - * Copyright 2020 by Contributors - * \file device_dmatrix.h - * \brief Device-memory version of DMatrix. 
- */ -#ifndef XGBOOST_DATA_DEVICE_DMATRIX_H_ -#define XGBOOST_DATA_DEVICE_DMATRIX_H_ - -#include -#include - -#include - -#include "adapter.h" -#include "simple_batch_iterator.h" -#include "simple_dmatrix.h" - -namespace xgboost { -namespace data { - -class DeviceDMatrix : public DMatrix { - public: - template - explicit DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin); - - MetaInfo& Info() override { return info_; } - - const MetaInfo& Info() const override { return info_; } - - bool SingleColBlock() const override { return true; } - - bool EllpackExists() const override { return true; } - bool SparsePageExists() const override { return false; } - DMatrix *Slice(common::Span ridxs) override { - LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix."; - return nullptr; - } - - private: - BatchSet GetRowBatches() override { - LOG(FATAL) << "Not implemented."; - return BatchSet(BatchIterator(nullptr)); - } - BatchSet GetColumnBatches() override { - LOG(FATAL) << "Not implemented."; - return BatchSet(BatchIterator(nullptr)); - } - BatchSet GetSortedColumnBatches() override { - LOG(FATAL) << "Not implemented."; - return BatchSet(BatchIterator(nullptr)); - } - BatchSet GetEllpackBatches(const BatchParam& param) override { - auto begin_iter = BatchIterator( - new SimpleBatchIteratorImpl(ellpack_page_.get())); - return BatchSet(begin_iter); - } - - MetaInfo info_; - // source data pointer. 
- std::unique_ptr ellpack_page_; -}; -} // namespace data -} // namespace xgboost -#endif // XGBOOST_DATA_DEVICE_DMATRIX_H_ diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index 5b953cf45..2e9f97c88 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -93,6 +93,11 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin batches++; } + if (device < 0) { // error or empty + this->page_.reset(new EllpackPage); + return; + } + common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device); for (auto const& sketch : sketch_containers) { final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data()); @@ -108,14 +113,23 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin this->info_.num_row_ = accumulated_rows; this->info_.num_nonzero_ = nnz; - // Construct the final ellpack page. - page_.reset(new EllpackPage); - *(page_->Impl()) = EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(), - row_stride, accumulated_rows); + auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows]() { + if (!page_) { + // Should be put inside the while loop to protect against empty batch. In + // that case device id is invalid. + page_.reset(new EllpackPage); + *(page_->Impl()) = + EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(), row_stride, + accumulated_rows); + } + }; + // Construct the final ellpack page. 
size_t offset = 0; iter.Reset(); + size_t n_batches_for_verification = 0; while (iter.Next()) { + init_page(); auto device = proxy->DeviceIdx(); dh::safe_cuda(cudaSetDevice(device)); auto rows = num_rows(); @@ -138,7 +152,10 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin if (batches != 1) { this->info_.Extend(std::move(proxy->Info()), false); } + n_batches_for_verification++; } + CHECK_EQ(batches, n_batches_for_verification) + << "Different number of batches returned between 2 iterations"; if (batches == 1) { this->info_ = std::move(proxy->Info()); diff --git a/tests/cpp/data/test_device_dmatrix.cu b/tests/cpp/data/test_device_dmatrix.cu deleted file mode 100644 index 7e0574c2e..000000000 --- a/tests/cpp/data/test_device_dmatrix.cu +++ /dev/null @@ -1,149 +0,0 @@ - -// Copyright (c) 2019 by Contributors -#include -#include -#include "../../../src/data/adapter.h" -#include "../../../src/data/ellpack_page.cuh" -#include "../../../src/data/device_dmatrix.h" -#include "../helpers.h" -#include -#include "../../../src/data/device_adapter.cuh" -#include "../../../src/gbm/gbtree_model.h" -#include "../common/test_hist_util.h" -#include "../../../src/common/compressed_iterator.h" -#include "../../../src/common/math.h" -#include "test_array_interface.h" -using namespace xgboost; // NOLINT - -TEST(DeviceDMatrix, RowMajor) { - int num_rows = 1000; - int num_columns = 50; - auto x = common::GenerateRandom(num_rows, num_columns); - auto x_device = thrust::device_vector(x); - auto adapter = common::AdapterFromData(x_device, num_rows, num_columns); - - data::DeviceDMatrix dmat(&adapter, - std::numeric_limits::quiet_NaN(), 1, 256); - - auto &batch = *dmat.GetBatches({0, 256, 0}).begin(); - auto impl = batch.Impl(); - common::CompressedIterator iterator( - impl->gidx_buffer.HostVector().data(), impl->NumSymbols()); - for(auto i = 0ull; i < x.size(); i++) - { - int column_idx = i % num_columns; - EXPECT_EQ(impl->Cuts().SearchBin(x[i], column_idx), 
iterator[i]); - } - EXPECT_EQ(dmat.Info().num_col_, num_columns); - EXPECT_EQ(dmat.Info().num_row_, num_rows); - EXPECT_EQ(dmat.Info().num_nonzero_, num_rows * num_columns); - -} - -TEST(DeviceDMatrix, RowMajorMissing) { - const float kMissing = std::numeric_limits::quiet_NaN(); - int num_rows = 10; - int num_columns = 2; - auto x = common::GenerateRandom(num_rows, num_columns); - x[1] = kMissing; - x[5] = kMissing; - x[6] = kMissing; - auto x_device = thrust::device_vector(x); - auto adapter = common::AdapterFromData(x_device, num_rows, num_columns); - - data::DeviceDMatrix dmat(&adapter, kMissing, 1, 256); - - auto &batch = *dmat.GetBatches({0, 256, 0}).begin(); - auto impl = batch.Impl(); - common::CompressedIterator iterator( - impl->gidx_buffer.HostVector().data(), impl->NumSymbols()); - EXPECT_EQ(iterator[1], impl->GetDeviceAccessor(0).NullValue()); - EXPECT_EQ(iterator[5], impl->GetDeviceAccessor(0).NullValue()); - // null values get placed after valid values in a row - EXPECT_EQ(iterator[7], impl->GetDeviceAccessor(0).NullValue()); - EXPECT_EQ(dmat.Info().num_col_, num_columns); - EXPECT_EQ(dmat.Info().num_row_, num_rows); - EXPECT_EQ(dmat.Info().num_nonzero_, num_rows*num_columns-3); - -} - -TEST(DeviceDMatrix, ColumnMajor) { - constexpr size_t kRows{100}; - std::vector columns; - thrust::device_vector d_data_0(kRows); - thrust::device_vector d_data_1(kRows); - - columns.emplace_back(GenerateDenseColumn("("::quiet_NaN(), - -1, 256); - auto &batch = *dmat.GetBatches({0, 256, 0}).begin(); - auto impl = batch.Impl(); - common::CompressedIterator iterator( - impl->gidx_buffer.HostVector().data(), impl->NumSymbols()); - - for (auto i = 0ull; i < kRows; i++) { - for (auto j = 0ull; j < columns.size(); j++) { - if (j == 0) { - EXPECT_EQ(iterator[i * 2 + j], impl->Cuts().SearchBin(d_data_0[i], j)); - } else { - EXPECT_EQ(iterator[i * 2 + j], impl->Cuts().SearchBin(d_data_1[i], j)); - } - } - } - EXPECT_EQ(dmat.Info().num_col_, 2); - EXPECT_EQ(dmat.Info().num_row_, 
kRows); - EXPECT_EQ(dmat.Info().num_nonzero_, kRows*2); - -} - -// Test equivalence with simple DMatrix -TEST(DeviceDMatrix, Equivalent) { - int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; - int num_columns = 5; - for (auto num_rows : sizes) { - auto x = common::GenerateRandom(num_rows, num_columns); - for (auto num_bins : bin_sizes) { - auto dmat = common::GetDMatrixFromData(x, num_rows, num_columns); - auto x_device = thrust::device_vector(x); - auto adapter = common::AdapterFromData(x_device, num_rows, num_columns); - data::DeviceDMatrix device_dmat( - &adapter, std::numeric_limits::quiet_NaN(), 1, num_bins); - - const auto &batch = *dmat->GetBatches({0, num_bins}).begin(); - const auto &device_dmat_batch = - *device_dmat.GetBatches({0, num_bins}).begin(); - - ASSERT_EQ(batch.Impl()->Cuts().Values(), device_dmat_batch.Impl()->Cuts().Values()); - ASSERT_EQ(batch.Impl()->gidx_buffer.HostVector(), - device_dmat_batch.Impl()->gidx_buffer.HostVector()); - } - } -} - -TEST(DeviceDMatrix, IsDense) { - int num_bins = 16; - auto test = [num_bins] (float sparsity) { - HostDeviceVector data; - std::string interface_str = RandomDataGenerator{10, 10, sparsity} - .Device(0).GenerateArrayInterface(&data); - data::CupyAdapter x{interface_str}; - std::unique_ptr device_dmat{ new data::DeviceDMatrix( - &x, std::numeric_limits::quiet_NaN(), 1, num_bins) }; - if (sparsity == 0.0) { - ASSERT_TRUE(device_dmat->IsDense()) << sparsity; - } else { - ASSERT_FALSE(device_dmat->IsDense()); - } - }; - test(0.0); - test(0.1); -} diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 77c5a1634..7cff50957 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -54,6 +54,7 @@ void TestTrainingPrediction(size_t rows, size_t bins, learner->SetParam("objective", "multi:softprob"); learner->SetParam("num_feature", std::to_string(kCols)); learner->SetParam("num_class", 
std::to_string(kClasses)); + learner->SetParam("max_bin", std::to_string(bins)); learner->Configure(); for (size_t i = 0; i < kIters; ++i) { diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index 9a2788e2d..0ba5931c9 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -170,3 +170,83 @@ Arrow specification.''' @pytest.mark.skipif(**tm.no_cudf()) def test_cudf_metainfo_device_dmatrix(self): _test_cudf_metainfo(xgb.DeviceQuantileDMatrix) + + +class IterForDMatrixTest(xgb.core.DataIter): + '''A data iterator for XGBoost DMatrix. + + `reset` and `next` are required for any data iterator, other functions here + are utilites for demonstration's purpose. + + ''' + ROWS_PER_BATCH = 100 # data is splited by rows + BATCHES = 16 + + def __init__(self): + '''Generate some random data for demostration. + + Actual data can be anything that is currently supported by XGBoost. + ''' + import cudf + self.rows = self.ROWS_PER_BATCH + rng = np.random.RandomState(1994) + self._data = [ + cudf.DataFrame( + {'a': rng.randn(self.ROWS_PER_BATCH), + 'b': rng.randn(self.ROWS_PER_BATCH)})] * self.BATCHES + self._labels = [rng.randn(self.rows)] * self.BATCHES + + self.it = 0 # set iterator to 0 + super().__init__() + + def as_array(self): + import cudf + return cudf.concat(self._data) + + def as_array_labels(self): + return np.concatenate(self._labels) + + def data(self): + '''Utility function for obtaining current batch of data.''' + return self._data[self.it] + + def labels(self): + '''Utility function for obtaining current batch of label.''' + return self._labels[self.it] + + def reset(self): + '''Reset the iterator''' + self.it = 0 + + def next(self, input_data): + '''Yield next batch of data''' + if self.it == len(self._data): + # Return 0 when there's no more batch. 
+ return 0 + input_data(data=self.data(), label=self.labels()) + self.it += 1 + return 1 + + +@pytest.mark.skipif(**tm.no_cudf()) +def test_from_cudf_iter(): + rounds = 100 + it = IterForDMatrixTest() + + # Use iterator + m_it = xgb.DeviceQuantileDMatrix(it) + reg_with_it = xgb.train({'tree_method': 'gpu_hist'}, m_it, + num_boost_round=rounds) + predict_with_it = reg_with_it.predict(m_it) + + # Without using iterator + m = xgb.DMatrix(it.as_array(), it.as_array_labels()) + + assert m_it.num_col() == m.num_col() + assert m_it.num_row() == m.num_row() + + reg = xgb.train({'tree_method': 'gpu_hist'}, m, + num_boost_round=rounds) + predict = reg.predict(m) + + np.testing.assert_allclose(predict_with_it, predict) diff --git a/tests/python-gpu/test_gpu_demos.py b/tests/python-gpu/test_gpu_demos.py new file mode 100644 index 000000000..03c4a6279 --- /dev/null +++ b/tests/python-gpu/test_gpu_demos.py @@ -0,0 +1,11 @@ +import os +import subprocess +import sys +sys.path.append("tests/python") +import test_demos as td # noqa + + +def test_data_iterator(): + script = os.path.join(td.PYTHON_DEMO_DIR, 'data_iterator.py') + cmd = ['python', script] + subprocess.check_call(cmd)