Accept iterator in device dmatrix. (#5783)

* Remove Device DMatrix.
Jiaming Yuan 2020-07-07 21:44:48 +08:00 committed by GitHub
parent 048d969be4
commit a3ec964346
12 changed files with 495 additions and 343 deletions

View File

@ -0,0 +1,109 @@
'''A demo for defining a data iterator.
The demo defines a customized iterator for passing batches of data into
`xgboost.DeviceQuantileDMatrix` and uses this `DeviceQuantileDMatrix` for
training.  The feature is primarily designed to reduce the required GPU
memory for training in a distributed environment.
After going through the demo, one might ask why we don't use a more native
Python iterator.  That's because XGBoost requires a `reset` function, while
using `itertools.tee` might incur significant memory usage according to:
https://docs.python.org/3/library/itertools.html#itertools.tee.
'''
import xgboost
import cupy
import numpy
COLS = 64
ROWS_PER_BATCH = 1000  # data is split by rows
BATCHES = 32
class IterForDMatrixDemo(xgboost.core.DataIter):
'''A data iterator for XGBoost DMatrix.
`reset` and `next` are required for any data iterator; the other functions
here are utilities for demonstration purposes.
'''
def __init__(self):
'''Generate some random data for demonstration.
Actual data can be anything that is currently supported by XGBoost.
'''
self.rows = ROWS_PER_BATCH
self.cols = COLS
rng = cupy.random.RandomState(1994)
self._data = [rng.randn(self.rows, self.cols)] * BATCHES
self._labels = [rng.randn(self.rows)] * BATCHES
self._weights = [rng.randn(self.rows)] * BATCHES
self.it = 0 # set iterator to 0
super().__init__()
def as_array(self):
return cupy.concatenate(self._data)
def as_array_labels(self):
return cupy.concatenate(self._labels)
def as_array_weights(self):
return cupy.concatenate(self._weights)
def data(self):
'''Utility function for obtaining current batch of data.'''
return self._data[self.it]
def labels(self):
'''Utility function for obtaining current batch of label.'''
return self._labels[self.it]
def weights(self):
return self._weights[self.it]
def reset(self):
'''Reset the iterator'''
self.it = 0
def next(self, input_data):
'''Yield next batch of data.'''
if self.it == len(self._data):
# Return 0 when there's no more batch.
return 0
input_data(data=self.data(), label=self.labels(),
weight=self.weights())
self.it += 1
return 1
def main():
rounds = 100
it = IterForDMatrixDemo()
# Use iterator, must be `DeviceQuantileDMatrix`
m_with_it = xgboost.DeviceQuantileDMatrix(it)
# Use regular DMatrix.
m = xgboost.DMatrix(it.as_array(), it.as_array_labels(),
weight=it.as_array_weights())
assert m_with_it.num_col() == m.num_col()
assert m_with_it.num_row() == m.num_row()
reg_with_it = xgboost.train({'tree_method': 'gpu_hist'}, m_with_it,
num_boost_round=rounds)
predict_with_it = reg_with_it.predict(m_with_it)
reg = xgboost.train({'tree_method': 'gpu_hist'}, m,
num_boost_round=rounds)
predict = reg.predict(m)
numpy.testing.assert_allclose(predict_with_it, predict,
rtol=1e-6)
if __name__ == '__main__':
main()
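The demo above keeps every batch in memory for simplicity. Below is a hedged sketch of the pattern the feature is really aimed at, where each call to `next` materializes only one batch on the GPU. The `part-*.npy` file layout and the column split are hypothetical; the sketch reuses the demo's `xgboost`, `cupy` and `numpy` imports.

class IterFromFiles(xgboost.core.DataIter):
    '''Stream batches from disk so only one batch lives in GPU memory at a time.'''
    def __init__(self, file_paths):
        # e.g. ['part-0.npy', 'part-1.npy', ...], one file per batch (hypothetical).
        self._file_paths = file_paths
        self._it = 0
        super().__init__()

    def reset(self):
        self._it = 0

    def next(self, input_data):
        if self._it == len(self._file_paths):
            return 0  # no more batches
        batch = numpy.load(self._file_paths[self._it])
        # Hypothetical layout: features in all columns but the last, label in the last.
        input_data(data=cupy.asarray(batch[:, :-1]),
                   label=cupy.asarray(batch[:, -1]))
        self._it += 1
        return 1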

View File

@ -300,6 +300,99 @@ def _cudf_array_interfaces(df):
return interfaces_str
class DataIter:
'''The interface for a user defined data iterator.  Currently it is only
supported by `DeviceQuantileDMatrix`.
'''
def __init__(self):
proxy_handle = ctypes.c_void_p()
_check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(proxy_handle)))
self._handle = DeviceQuantileDMatrix(proxy_handle)
self.exception = None
@property
def proxy(self):
'''Handle of the DMatrix proxy.'''
return self._handle
def reset_wrapper(self, this): # pylint: disable=unused-argument
'''A wrapper for user defined `reset` function.'''
self.reset()
def next_wrapper(self, this): # pylint: disable=unused-argument
'''A wrapper for user defined `next` function.
`this` is not used in Python.  ctypes can handle `self` of a Python
member function automatically when converting it to a C function
pointer.
'''
if self.exception is not None:
return 0
def data_handle(data, label=None, weight=None, base_margin=None,
group=None,
label_lower_bound=None, label_upper_bound=None):
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
# pylint: disable=protected-access
self.proxy._set_data_from_cuda_columnar(data)
elif lazy_isinstance(data, 'cudf.core.series', 'Series'):
# pylint: disable=protected-access
self.proxy._set_data_from_cuda_columnar(data)
elif lazy_isinstance(data, 'cupy.core.core', 'ndarray'):
# pylint: disable=protected-access
self.proxy._set_data_from_cuda_interface(data)
else:
raise TypeError(
'Value type is not supported for data iterator: ' +
str(type(data)))
self.proxy.set_info(label=label, weight=weight,
base_margin=base_margin,
group=group,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound)
try:
# Defer the exception in order to return 0 and stop the iteration.
# An exception inside a ctypes callback function has no effect except
# printing to stderr (it doesn't stop the execution).
ret = self.next(data_handle) # pylint: disable=not-callable
except Exception as e: # pylint: disable=broad-except
tb = sys.exc_info()[2]
print('Got an exception in Python')
self.exception = e.with_traceback(tb)
return 0
return ret
def reset(self):
'''Reset the data iterator. Prototype for user defined function.'''
raise NotImplementedError()
def next(self, input_data):
'''Set the next batch of data.
Parameters
----------
input_data: callable
A function with the same data fields (e.g. `data`, `label`) as
`xgboost.DMatrix`.
Returns
-------
0 if there's no more batch, otherwise 1.
'''
raise NotImplementedError()
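As a minimal sketch of the contract defined above (the class name and the in-memory batch list are illustrative): a subclass only needs `reset` and `next`, where `next` feeds one batch through the supplied callback and returns 1, or returns 0 once the batches are exhausted.

class ListIter(DataIter):
    '''Iterate over an in-memory list of (data, label) batches.'''
    def __init__(self, batches):
        self._batches = batches  # e.g. a list of (cupy_array, cupy_labels) tuples
        self._it = 0
        super().__init__()

    def reset(self):
        self._it = 0

    def next(self, input_data):
        if self._it == len(self._batches):
            return 0  # signal that iteration is finished
        data, label = self._batches[self._it]
        input_data(data=data, label=label)
        self._it += 1
        return 1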
class DMatrix: # pylint: disable=too-many-instance-attributes
"""Data Matrix used in XGBoost.
@ -361,36 +454,65 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.handle = None
return
handler = self.get_data_handler(data)
handler = self._get_data_handler(data)
can_handle_meta = False
if handler is None:
data = _convert_unknown_data(data, None)
handler = self.get_data_handler(data)
handler = self._get_data_handler(data)
try:
handler.handle_meta(label, weight, base_margin)
can_handle_meta = True
except NotImplementedError:
can_handle_meta = False
self.handle, feature_names, feature_types = handler.handle_input(
data, feature_names, feature_types)
assert self.handle, 'Failed to construct a DMatrix.'
if label is not None:
self.set_label(label)
if weight is not None:
self.set_weight(weight)
if base_margin is not None:
self.set_base_margin(base_margin)
if not can_handle_meta:
self.set_info(label, weight, base_margin)
self.feature_names = feature_names
self.feature_types = feature_types
def get_data_handler(self, data, meta=None, meta_type=None):
def _get_data_handler(self, data, meta=None, meta_type=None):
'''Get data handler for this DMatrix class.'''
from .data import get_dmatrix_data_handler
handler = get_dmatrix_data_handler(
data, self.missing, self.nthread, self.silent, meta, meta_type)
return handler
# pylint: disable=no-self-use
def _get_meta_handler(self, data, meta, meta_type):
from .data import get_dmatrix_meta_handler
handler = get_dmatrix_meta_handler(
data, meta, meta_type)
return handler
def __del__(self):
if hasattr(self, "handle") and self.handle:
_check_call(_LIB.XGDMatrixFree(self.handle))
self.handle = None
def set_info(self,
label=None, weight=None, base_margin=None,
group=None,
label_lower_bound=None,
label_upper_bound=None):
'''Set meta info for DMatrix.'''
if label is not None:
self.set_label(label)
if weight is not None:
self.set_weight(weight)
if base_margin is not None:
self.set_base_margin(base_margin)
if group is not None:
self.set_group(group)
if label_lower_bound is not None:
self.set_float_info('label_lower_bound', label_lower_bound)
if label_upper_bound is not None:
self.set_float_info('label_upper_bound', label_upper_bound)
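For reference, a short usage sketch of `set_info` on an ordinary `DMatrix`; the random arrays are illustrative only.

import numpy as np
import xgboost

X = np.random.randn(100, 4)
y = np.random.randn(100)
w = np.abs(np.random.randn(100))

dtrain = xgboost.DMatrix(X)
# Attach all meta information in one call instead of separate setters.
dtrain.set_info(label=y, weight=w)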
def get_float_info(self, field):
"""Get float property from the DMatrix.
@ -447,10 +569,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
if isinstance(data, np.ndarray):
self.set_float_info_npy2d(field, data)
return
handler = self.get_data_handler(data, field, np.float32)
if handler is None:
data = _convert_unknown_data(data, field, np.float32)
handler = self.get_data_handler(data, field, np.float32)
handler = self._get_data_handler(data, field, np.float32)
assert handler
data, _, _ = handler.transform(data)
c_data = c_array(ctypes.c_float, data)
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
@ -470,7 +590,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
data: numpy array
The array of data to be set
"""
data, _, _ = self.get_data_handler(
data, _, _ = self._get_meta_handler(
data, field, np.float32).transform(data)
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
@ -489,7 +609,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
data: numpy array
The array of data to be set
"""
data, _, _ = self.get_data_handler(
data, _, _ = self._get_data_handler(
data, field, 'uint32').transform(data)
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
c_str(field),
@ -803,11 +923,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
class DeviceQuantileDMatrix(DMatrix):
"""Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do not
use this for test/validation tasks as some information may be lost in quantisation. This
DMatrix is primarily designed to save memory in training from device memory inputs by
avoiding intermediate storage. Implementation does not currently consider weights in
quantisation process(unlike DMatrix). Set max_bin to control the number of bins during
"""Device memory Data Matrix used in XGBoost for training with
tree_method='gpu_hist'. Do not use this for test/validation tasks as some
information may be lost in quantisation. This DMatrix is primarily designed
to save memory in training from device memory inputs by avoiding
intermediate storage. Set max_bin to control the number of bins during
quantisation.
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
@ -823,6 +943,9 @@ class DeviceQuantileDMatrix(DMatrix):
feature_types=None,
nthread=None, max_bin=256):
self.max_bin = max_bin
if isinstance(data, ctypes.c_void_p):
self.handle = data
return
super().__init__(data, label=label, weight=weight,
base_margin=base_margin,
missing=missing,
@ -831,11 +954,32 @@ class DeviceQuantileDMatrix(DMatrix):
feature_types=feature_types,
nthread=nthread)
def get_data_handler(self, data, meta=None, meta_type=None):
def _get_data_handler(self, data, meta=None, meta_type=None):
from .data import get_device_quantile_dmatrix_data_handler
return get_device_quantile_dmatrix_data_handler(
data, self.max_bin, self.missing, self.nthread, self.silent)
def _set_data_from_cuda_interface(self, data):
'''Set data from CUDA array interface.'''
interface = data.__cuda_array_interface__
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
_check_call(
_LIB.XGDeviceQuantileDMatrixSetDataCudaArrayInterface(
self.handle,
interface_str
)
)
def _set_data_from_cuda_columnar(self, data):
'''Set data from CUDA columnar format.'''
interfaces_str = _cudf_array_interfaces(data)
_check_call(
_LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
self.handle,
interfaces_str
)
)
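Reflecting the class docstring above, here is a hedged sketch of the intended split between training and evaluation inputs (a CUDA device and the cupy arrays are assumed): train `gpu_hist` on the quantised `DeviceQuantileDMatrix`, but evaluate and predict with a regular `DMatrix`.

import cupy
import xgboost

X = cupy.random.randn(1000, 10)
y = cupy.random.randn(1000)

dtrain = xgboost.DeviceQuantileDMatrix(X, y, max_bin=256)  # training only
dvalid = xgboost.DMatrix(X, y)                             # evaluation/prediction

booster = xgboost.train({'tree_method': 'gpu_hist'}, dtrain,
                        num_boost_round=10, evals=[(dvalid, 'validation')])
predictions = booster.predict(dvalid)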
class Booster(object):
# pylint: disable=too-many-public-methods

View File

@ -1,4 +1,4 @@
# pylint: disable=too-many-arguments, no-self-use
# pylint: disable=too-many-arguments, no-self-use, too-many-instance-attributes
'''Data dispatching for DMatrix.'''
import ctypes
import abc
@ -8,6 +8,7 @@ import warnings
import numpy as np
from .core import c_array, _LIB, _check_call, c_str, _cudf_array_interfaces
from .core import DataIter
from .compat import lazy_isinstance, STRING_TYPES, os_fspath, os_PathLike
c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
@ -23,6 +24,18 @@ class DataHandler(abc.ABC):
self.meta = meta
self.meta_type = meta_type
def handle_meta(self, label=None, weight=None, base_margin=None,
group=None,
label_lower_bound=None,
label_upper_bound=None):
'''Handle meta data when the DMatrix type cannot defer setting meta data
until after construction.  An example is `DeviceQuantileDMatrix`, which
requires weight to be present before digesting data.
'''
raise NotImplementedError()
def _warn_unused_missing(self, data):
if not (np.isnan(self.missing) or self.missing is None):
warnings.warn(
@ -116,6 +129,14 @@ def get_dmatrix_data_handler(data, missing, nthread, silent,
return handler(missing, nthread, silent, meta, meta_type)
def get_dmatrix_meta_handler(data, meta, meta_type):
'''Get handler for meta instead of data.'''
handler = __dmatrix_registry.get_handler(data)
if handler is None:
return None
return handler(None, 0, True, meta, meta_type)
class FileHandler(DataHandler):
'''Handler of path like input.'''
def handle_input(self, data, feature_names, feature_types):
@ -511,6 +532,43 @@ __dmatrix_registry.register_handler_opaque(
DLPackHandler)
class SingleBatchInternalIter(DataIter):
'''An iterator for single batch data to help create the device DMatrix.
Transforming the input directly to a histogram with the normal single-batch
data API cannot access the weight for sketching, so this iterator acts as a
staging area for meta info.
'''
def __init__(self, data, label, weight, base_margin, group,
label_lower_bound, label_upper_bound):
self.data = data
self.label = label
self.weight = weight
self.base_margin = base_margin
self.group = group
self.label_lower_bound = label_lower_bound
self.label_upper_bound = label_upper_bound
self.it = 0 # pylint: disable=invalid-name
super().__init__()
def next(self, input_data):
if self.it == 1:
return 0
self.it += 1
input_data(data=self.data, label=self.label,
weight=self.weight, base_margin=self.base_margin,
group=self.group,
label_lower_bound=self.label_lower_bound,
label_upper_bound=self.label_upper_bound)
return 1
def reset(self):
self.it = 0
__device_quantile_dmatrix_registry = DMatrixDataManager() # pylint: disable=invalid-name
class DeviceQuantileDMatrixDataHandler(DataHandler): # pylint: disable=abstract-method
'''Base class of data handler for `DeviceQuantileDMatrix`.'''
def __init__(self, max_bin, missing, nthread, silent,
@ -518,8 +576,53 @@ class DeviceQuantileDMatrixDataHandler(DataHandler): # pylint: disable=abstract
self.max_bin = max_bin
super().__init__(missing, nthread, silent, meta, meta_type)
def handle_meta(self, label=None, weight=None, base_margin=None,
group=None,
label_lower_bound=None,
label_upper_bound=None):
self.label = label
self.weight = weight
self.base_margin = base_margin
self.group = group
self.label_lower_bound = label_lower_bound
self.label_upper_bound = label_upper_bound
__device_quantile_dmatrix_registry = DMatrixDataManager() # pylint: disable=invalid-name
def handle_input(self, data, feature_names, feature_types):
if not isinstance(data, DataIter):
it = SingleBatchInternalIter(
data, self.label, self.weight,
self.base_margin, self.group,
self.label_lower_bound, self.label_upper_bound)
else:
it = data
reset_factory = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
reset_callback = reset_factory(it.reset_wrapper)
next_factory = ctypes.CFUNCTYPE(
ctypes.c_int,
ctypes.c_void_p,
)
next_callback = next_factory(it.next_wrapper)
handle = ctypes.c_void_p()
ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
None,
it.proxy.handle,
reset_callback,
next_callback,
ctypes.c_float(self.missing),
ctypes.c_int(self.nthread),
ctypes.c_int(self.max_bin),
ctypes.byref(handle)
)
if it.exception:
raise it.exception
# delay check_call to throw intermediate exception first
_check_call(ret)
return handle, feature_names, feature_types
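The reset/next wiring above is the standard ctypes callback pattern; below is a standalone sketch of that pattern with a toy counter standing in for the proxy DMatrix (no XGBoost involved).

import ctypes

class Counter:
    '''Toy iterator driven through C-style callbacks.'''
    def __init__(self, n):
        self.n = n
        self.i = 0

    def reset_wrapper(self, this):  # matches CFUNCTYPE(None, c_void_p)
        self.i = 0

    def next_wrapper(self, this):   # matches CFUNCTYPE(c_int, c_void_p)
        if self.i == self.n:
            return 0
        self.i += 1
        return 1

it = Counter(3)
reset_cb = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
next_cb = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)(it.next_wrapper)

# The C library would invoke these; calling them directly shows the protocol.
while next_cb(None):
    pass
reset_cb(None)
assert it.i == 0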
__device_quantile_dmatrix_registry.register_handler_opaque(
lambda x: isinstance(x, DataIter),
DeviceQuantileDMatrixDataHandler)
def get_device_quantile_dmatrix_data_handler(
@ -549,19 +652,7 @@ class DeviceQuantileCudaArrayInterfaceHandler(
data, '__array__'):
import cupy # pylint: disable=import-error
data = cupy.array(data, copy=False)
interface = data.__cuda_array_interface__
if 'mask' in interface:
interface['mask'] = interface['mask'].__cuda_array_interface__
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
handle = ctypes.c_void_p()
_check_call(
_LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
interface_str,
ctypes.c_float(self.missing), ctypes.c_int(self.nthread),
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
return handle, feature_names, feature_types
return super().handle_input(data, feature_names, feature_types)
__device_quantile_dmatrix_registry.register_handler(
@ -582,14 +673,7 @@ class DeviceQuantileCudaColumnarHandler(DeviceQuantileDMatrixDataHandler,
"""Initialize Quantile Device DMatrix from columnar memory format."""
data, feature_names, feature_types = self._maybe_cudf_dataframe(
data, feature_names, feature_types)
interfaces_str = _cudf_array_interfaces(data)
handle = ctypes.c_void_p()
_check_call(
_LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
interfaces_str,
ctypes.c_float(self.missing), ctypes.c_int(self.nthread),
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
return handle, feature_names, feature_types
return super().handle_input(data, feature_names, feature_types)
__device_quantile_dmatrix_registry.register_handler(

View File

@ -4,7 +4,6 @@
#include "xgboost/learner.h"
#include "c_api_error.h"
#include "../data/device_adapter.cuh"
#include "../data/device_dmatrix.h"
using namespace xgboost; // NOLINT
@ -31,28 +30,6 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
API_END();
}
XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
bst_float missing, int nthread, int max_bin,
DMatrixHandle* out) {
API_BEGIN();
std::string json_str{c_json_strs};
data::CudfAdapter adapter(json_str);
*out =
new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
API_END();
}
XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterface(char const* c_json_strs,
bst_float missing, int nthread, int max_bin,
DMatrixHandle* out) {
API_BEGIN();
std::string json_str{c_json_strs};
data::CupyAdapter adapter(json_str);
*out =
new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
API_END();
}
// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
char const* c_json_strs,

View File

@ -201,7 +201,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
// Returns maximum row length
template <typename AdapterBatchT>
size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
int device_idx, float missing) {
IsValidFunctor is_valid(missing);
// Count elements per row

View File

@ -1,58 +0,0 @@
/*!
* Copyright 2020 by Contributors
* \file device_dmatrix.cu
* \brief Device-memory version of DMatrix.
*/
#include <thrust/execution_policy.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <memory>
#include <utility>
#include "../common/hist_util.h"
#include "adapter.h"
#include "device_adapter.cuh"
#include "ellpack_page.cuh"
#include "device_dmatrix.h"
namespace xgboost {
namespace data {
// Does not currently support metainfo as no on-device data source contains this
// Current implementation assumes a single batch. More batches can
// be supported in future. Does not currently support inferring row/column size
template <typename AdapterT>
DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin) {
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));
auto& batch = adapter->Value();
// Work out how many valid entries we have in each row
dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(),
row_counts.size());
size_t row_stride =
GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);
dh::XGBCachingDeviceAllocator<char> alloc;
info_.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
row_counts.begin(), row_counts.end());
info_.num_col_ = adapter->NumColumns();
info_.num_row_ = adapter->NumRows();
ellpack_page_.reset(new EllpackPage());
*ellpack_page_->Impl() =
EllpackPageImpl(adapter, missing, this->IsDense(), nthread, max_bin,
row_counts_span, row_stride);
// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
}
#define DEVICE_DMARIX_SPECIALIZATION(__ADAPTER_T) \
template DeviceDMatrix::DeviceDMatrix(__ADAPTER_T* adapter, float missing, \
int nthread, int max_bin);
DEVICE_DMARIX_SPECIALIZATION(CudfAdapter);
DEVICE_DMARIX_SPECIALIZATION(CupyAdapter);
} // namespace data
} // namespace xgboost

View File

@ -1,64 +0,0 @@
/*!
* Copyright 2020 by Contributors
* \file device_dmatrix.h
* \brief Device-memory version of DMatrix.
*/
#ifndef XGBOOST_DATA_DEVICE_DMATRIX_H_
#define XGBOOST_DATA_DEVICE_DMATRIX_H_
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <memory>
#include "adapter.h"
#include "simple_batch_iterator.h"
#include "simple_dmatrix.h"
namespace xgboost {
namespace data {
class DeviceDMatrix : public DMatrix {
public:
template <typename AdapterT>
explicit DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin);
MetaInfo& Info() override { return info_; }
const MetaInfo& Info() const override { return info_; }
bool SingleColBlock() const override { return true; }
bool EllpackExists() const override { return true; }
bool SparsePageExists() const override { return false; }
DMatrix *Slice(common::Span<int32_t const> ridxs) override {
LOG(FATAL) << "Slicing DMatrix is not supported for Device DMatrix.";
return nullptr;
}
private:
BatchSet<SparsePage> GetRowBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
}
BatchSet<CSCPage> GetColumnBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
}
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
}
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override {
auto begin_iter = BatchIterator<EllpackPage>(
new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
return BatchSet<EllpackPage>(begin_iter);
}
MetaInfo info_;
// source data pointer.
std::unique_ptr<EllpackPage> ellpack_page_;
};
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_DEVICE_DMATRIX_H_

View File

@ -93,6 +93,11 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
batches++;
}
if (device < 0) { // error or empty
this->page_.reset(new EllpackPage);
return;
}
common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
for (auto const& sketch : sketch_containers) {
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
@ -108,14 +113,23 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
this->info_.num_row_ = accumulated_rows;
this->info_.num_nonzero_ = nnz;
// Construct the final ellpack page.
auto init_page = [this, &proxy, &cuts, row_stride, accumulated_rows]() {
if (!page_) {
// Should be put inside the while loop to protect against empty batch. In
// that case device id is invalid.
page_.reset(new EllpackPage);
*(page_->Impl()) = EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(),
row_stride, accumulated_rows);
*(page_->Impl()) =
EllpackPageImpl(proxy->DeviceIdx(), cuts, this->IsDense(), row_stride,
accumulated_rows);
}
};
// Construct the final ellpack page.
size_t offset = 0;
iter.Reset();
size_t n_batches_for_verification = 0;
while (iter.Next()) {
init_page();
auto device = proxy->DeviceIdx();
dh::safe_cuda(cudaSetDevice(device));
auto rows = num_rows();
@ -138,7 +152,10 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
if (batches != 1) {
this->info_.Extend(std::move(proxy->Info()), false);
}
n_batches_for_verification++;
}
CHECK_EQ(batches, n_batches_for_verification)
<< "Different number of batches returned between 2 iterations";
if (batches == 1) {
this->info_ = std::move(proxy->Info());

View File

@ -1,149 +0,0 @@
// Copyright (c) 2019 by Contributors
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include "../../../src/data/adapter.h"
#include "../../../src/data/ellpack_page.cuh"
#include "../../../src/data/device_dmatrix.h"
#include "../helpers.h"
#include <thrust/device_vector.h>
#include "../../../src/data/device_adapter.cuh"
#include "../../../src/gbm/gbtree_model.h"
#include "../common/test_hist_util.h"
#include "../../../src/common/compressed_iterator.h"
#include "../../../src/common/math.h"
#include "test_array_interface.h"
using namespace xgboost; // NOLINT
TEST(DeviceDMatrix, RowMajor) {
int num_rows = 1000;
int num_columns = 50;
auto x = common::GenerateRandom(num_rows, num_columns);
auto x_device = thrust::device_vector<float>(x);
auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);
data::DeviceDMatrix dmat(&adapter,
std::numeric_limits<float>::quiet_NaN(), 1, 256);
auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
auto impl = batch.Impl();
common::CompressedIterator<uint32_t> iterator(
impl->gidx_buffer.HostVector().data(), impl->NumSymbols());
for(auto i = 0ull; i < x.size(); i++)
{
int column_idx = i % num_columns;
EXPECT_EQ(impl->Cuts().SearchBin(x[i], column_idx), iterator[i]);
}
EXPECT_EQ(dmat.Info().num_col_, num_columns);
EXPECT_EQ(dmat.Info().num_row_, num_rows);
EXPECT_EQ(dmat.Info().num_nonzero_, num_rows * num_columns);
}
TEST(DeviceDMatrix, RowMajorMissing) {
const float kMissing = std::numeric_limits<float>::quiet_NaN();
int num_rows = 10;
int num_columns = 2;
auto x = common::GenerateRandom(num_rows, num_columns);
x[1] = kMissing;
x[5] = kMissing;
x[6] = kMissing;
auto x_device = thrust::device_vector<float>(x);
auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);
data::DeviceDMatrix dmat(&adapter, kMissing, 1, 256);
auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
auto impl = batch.Impl();
common::CompressedIterator<uint32_t> iterator(
impl->gidx_buffer.HostVector().data(), impl->NumSymbols());
EXPECT_EQ(iterator[1], impl->GetDeviceAccessor(0).NullValue());
EXPECT_EQ(iterator[5], impl->GetDeviceAccessor(0).NullValue());
// null values get placed after valid values in a row
EXPECT_EQ(iterator[7], impl->GetDeviceAccessor(0).NullValue());
EXPECT_EQ(dmat.Info().num_col_, num_columns);
EXPECT_EQ(dmat.Info().num_row_, num_rows);
EXPECT_EQ(dmat.Info().num_nonzero_, num_rows*num_columns-3);
}
TEST(DeviceDMatrix, ColumnMajor) {
constexpr size_t kRows{100};
std::vector<Json> columns;
thrust::device_vector<double> d_data_0(kRows);
thrust::device_vector<uint32_t> d_data_1(kRows);
columns.emplace_back(GenerateDenseColumn<double>("<f8", kRows, &d_data_0));
columns.emplace_back(GenerateDenseColumn<uint32_t>("<u4", kRows, &d_data_1));
Json column_arr{columns};
std::string str;
Json::Dump(column_arr, &str);
data::CudfAdapter adapter(str);
data::DeviceDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
-1, 256);
auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
auto impl = batch.Impl();
common::CompressedIterator<uint32_t> iterator(
impl->gidx_buffer.HostVector().data(), impl->NumSymbols());
for (auto i = 0ull; i < kRows; i++) {
for (auto j = 0ull; j < columns.size(); j++) {
if (j == 0) {
EXPECT_EQ(iterator[i * 2 + j], impl->Cuts().SearchBin(d_data_0[i], j));
} else {
EXPECT_EQ(iterator[i * 2 + j], impl->Cuts().SearchBin(d_data_1[i], j));
}
}
}
EXPECT_EQ(dmat.Info().num_col_, 2);
EXPECT_EQ(dmat.Info().num_row_, kRows);
EXPECT_EQ(dmat.Info().num_nonzero_, kRows*2);
}
// Test equivalence with simple DMatrix
TEST(DeviceDMatrix, Equivalent) {
int bin_sizes[] = {2, 16, 256, 512};
int sizes[] = {100, 1000, 1500};
int num_columns = 5;
for (auto num_rows : sizes) {
auto x = common::GenerateRandom(num_rows, num_columns);
for (auto num_bins : bin_sizes) {
auto dmat = common::GetDMatrixFromData(x, num_rows, num_columns);
auto x_device = thrust::device_vector<float>(x);
auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);
data::DeviceDMatrix device_dmat(
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, num_bins);
const auto &batch = *dmat->GetBatches<EllpackPage>({0, num_bins}).begin();
const auto &device_dmat_batch =
*device_dmat.GetBatches<EllpackPage>({0, num_bins}).begin();
ASSERT_EQ(batch.Impl()->Cuts().Values(), device_dmat_batch.Impl()->Cuts().Values());
ASSERT_EQ(batch.Impl()->gidx_buffer.HostVector(),
device_dmat_batch.Impl()->gidx_buffer.HostVector());
}
}
}
TEST(DeviceDMatrix, IsDense) {
int num_bins = 16;
auto test = [num_bins] (float sparsity) {
HostDeviceVector<float> data;
std::string interface_str = RandomDataGenerator{10, 10, sparsity}
.Device(0).GenerateArrayInterface(&data);
data::CupyAdapter x{interface_str};
std::unique_ptr<data::DeviceDMatrix> device_dmat{ new data::DeviceDMatrix(
&x, std::numeric_limits<float>::quiet_NaN(), 1, num_bins) };
if (sparsity == 0.0) {
ASSERT_TRUE(device_dmat->IsDense()) << sparsity;
} else {
ASSERT_FALSE(device_dmat->IsDense());
}
};
test(0.0);
test(0.1);
}

View File

@ -54,6 +54,7 @@ void TestTrainingPrediction(size_t rows, size_t bins,
learner->SetParam("objective", "multi:softprob");
learner->SetParam("num_feature", std::to_string(kCols));
learner->SetParam("num_class", std::to_string(kClasses));
learner->SetParam("max_bin", std::to_string(bins));
learner->Configure();
for (size_t i = 0; i < kIters; ++i) {

View File

@ -170,3 +170,83 @@ Arrow specification.'''
@pytest.mark.skipif(**tm.no_cudf())
def test_cudf_metainfo_device_dmatrix(self):
_test_cudf_metainfo(xgb.DeviceQuantileDMatrix)
class IterForDMatrixTest(xgb.core.DataIter):
'''A data iterator for XGBoost DMatrix.
`reset` and `next` are required for any data iterator; the other functions
here are utilities for demonstration purposes.
'''
ROWS_PER_BATCH = 100  # data is split by rows
BATCHES = 16
def __init__(self):
'''Generate some random data for demonstration.
Actual data can be anything that is currently supported by XGBoost.
'''
import cudf
self.rows = self.ROWS_PER_BATCH
rng = np.random.RandomState(1994)
self._data = [
cudf.DataFrame(
{'a': rng.randn(self.ROWS_PER_BATCH),
'b': rng.randn(self.ROWS_PER_BATCH)})] * self.BATCHES
self._labels = [rng.randn(self.rows)] * self.BATCHES
self.it = 0 # set iterator to 0
super().__init__()
def as_array(self):
import cudf
return cudf.concat(self._data)
def as_array_labels(self):
return np.concatenate(self._labels)
def data(self):
'''Utility function for obtaining current batch of data.'''
return self._data[self.it]
def labels(self):
'''Utility function for obtaining current batch of label.'''
return self._labels[self.it]
def reset(self):
'''Reset the iterator'''
self.it = 0
def next(self, input_data):
'''Yield next batch of data'''
if self.it == len(self._data):
# Return 0 when there's no more batch.
return 0
input_data(data=self.data(), label=self.labels())
self.it += 1
return 1
@pytest.mark.skipif(**tm.no_cudf())
def test_from_cudf_iter():
rounds = 100
it = IterForDMatrixTest()
# Use iterator
m_it = xgb.DeviceQuantileDMatrix(it)
reg_with_it = xgb.train({'tree_method': 'gpu_hist'}, m_it,
num_boost_round=rounds)
predict_with_it = reg_with_it.predict(m_it)
# Without using iterator
m = xgb.DMatrix(it.as_array(), it.as_array_labels())
assert m_it.num_col() == m.num_col()
assert m_it.num_row() == m.num_row()
reg = xgb.train({'tree_method': 'gpu_hist'}, m,
num_boost_round=rounds)
predict = reg.predict(m)
np.testing.assert_allclose(predict_with_it, predict)

View File

@ -0,0 +1,11 @@
import os
import subprocess
import sys
sys.path.append("tests/python")
import test_demos as td # noqa
def test_data_iterator():
script = os.path.join(td.PYTHON_DEMO_DIR, 'data_iterator.py')
cmd = ['python', script]
subprocess.check_call(cmd)