Device dmatrix (#5420)
This commit is contained in:
parent
780de49ddb
commit
13b10a6370
@ -8,12 +8,13 @@ import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from .core import DMatrix, Booster
|
||||
from .core import DMatrix, DeviceQuantileDMatrix, Booster
|
||||
from .training import train, cv
|
||||
from . import rabit # noqa
|
||||
from . import rabit # noqa
|
||||
from . import tracker # noqa
|
||||
from .tracker import RabitTracker # noqa
|
||||
from . import dask
|
||||
|
||||
try:
|
||||
from .sklearn import XGBModel, XGBClassifier, XGBRegressor, XGBRanker
|
||||
from .sklearn import XGBRFClassifier, XGBRFRegressor
|
||||
@ -31,7 +32,7 @@ VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
|
||||
with open(VERSION_FILE) as f:
|
||||
__version__ = f.read().strip()
|
||||
|
||||
__all__ = ['DMatrix', 'Booster',
|
||||
__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster',
|
||||
'train', 'cv',
|
||||
'RabitTracker',
|
||||
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
|
||||
|
||||
@ -291,6 +291,18 @@ def _maybe_pandas_data(data, feature_names, feature_types,
|
||||
return data, feature_names, feature_types
|
||||
|
||||
|
||||
def _cudf_array_interfaces(df):
|
||||
'''Extract CuDF __cuda_array_interface__'''
|
||||
interfaces = []
|
||||
for col in df:
|
||||
interface = df[col].__cuda_array_interface__
|
||||
if 'mask' in interface:
|
||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
||||
interfaces.append(interface)
|
||||
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
|
||||
return interfaces_str
|
||||
|
||||
|
||||
def _maybe_cudf_dataframe(data, feature_names, feature_types):
|
||||
"""Extract internal data from cudf.DataFrame for DMatrix data."""
|
||||
if not (CUDF_INSTALLED and isinstance(data,
|
||||
@ -596,16 +608,10 @@ class DMatrix(object):
|
||||
|
||||
def _init_from_array_interface_columns(self, df, missing, nthread):
|
||||
"""Initialize DMatrix from columnar memory format."""
|
||||
interfaces = []
|
||||
for col in df:
|
||||
interface = df[col].__cuda_array_interface__
|
||||
if 'mask' in interface:
|
||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
||||
interfaces.append(interface)
|
||||
interfaces_str = _cudf_array_interfaces(df)
|
||||
handle = ctypes.c_void_p()
|
||||
missing = missing if missing is not None else np.nan
|
||||
nthread = nthread if nthread is not None else 1
|
||||
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
interfaces_str,
|
||||
@ -1005,6 +1011,65 @@ class DMatrix(object):
|
||||
self._feature_types = feature_types
|
||||
|
||||
|
||||
class DeviceQuantileDMatrix(DMatrix):
|
||||
"""Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do not
|
||||
use this for test/validation tasks as some information may be lost in quantisation. This
|
||||
DMatrix is primarily designed to save memory in training and avoids intermediate steps,
|
||||
directly creating a compressed representation for training without allocating additional
|
||||
memory. Implementation does not currently consider weights in quantisation process(unlike
|
||||
DMatrix).
|
||||
|
||||
You can construct DeviceDMatrix from cupy/cudf
|
||||
"""
|
||||
|
||||
def __init__(self, data, label=None, weight=None, base_margin=None,
|
||||
missing=None,
|
||||
silent=False,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
nthread=None, max_bin=256):
|
||||
self.max_bin = max_bin
|
||||
if not (hasattr(data, "__cuda_array_interface__") or (
|
||||
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))):
|
||||
raise ValueError('Only cupy/cudf currently supported for DeviceDMatrix')
|
||||
|
||||
super().__init__(data, label=label, weight=weight, base_margin=base_margin,
|
||||
missing=missing,
|
||||
silent=silent,
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
nthread=nthread)
|
||||
|
||||
def _init_from_array_interface_columns(self, df, missing, nthread):
|
||||
"""Initialize DMatrix from columnar memory format."""
|
||||
interfaces_str = _cudf_array_interfaces(df)
|
||||
handle = ctypes.c_void_p()
|
||||
missing = missing if missing is not None else np.nan
|
||||
nthread = nthread if nthread is not None else 1
|
||||
_check_call(
|
||||
_LIB.XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(
|
||||
interfaces_str,
|
||||
ctypes.c_float(missing), ctypes.c_int(nthread),
|
||||
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
|
||||
self.handle = handle
|
||||
|
||||
def _init_from_array_interface(self, data, missing, nthread):
|
||||
"""Initialize DMatrix from cupy ndarray."""
|
||||
interface = data.__cuda_array_interface__
|
||||
if 'mask' in interface:
|
||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
||||
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
|
||||
|
||||
handle = ctypes.c_void_p()
|
||||
missing = missing if missing is not None else np.nan
|
||||
nthread = nthread if nthread is not None else 1
|
||||
_check_call(
|
||||
_LIB.XGDeviceQuantileDMatrixCreateFromArrayInterface(
|
||||
interface_str,
|
||||
ctypes.c_float(missing), ctypes.c_int(nthread),
|
||||
ctypes.c_int(self.max_bin), ctypes.byref(handle)))
|
||||
self.handle = handle
|
||||
|
||||
class Booster(object):
|
||||
# pylint: disable=too-many-public-methods
|
||||
"""A Booster of XGBoost.
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
#include "xgboost/learner.h"
|
||||
#include "c_api_error.h"
|
||||
#include "../data/device_adapter.cuh"
|
||||
#include "../data/device_dmatrix.h"
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
@ -29,3 +30,25 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
|
||||
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
|
||||
bst_float missing, int nthread, int max_bin,
|
||||
DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
std::string json_str{c_json_strs};
|
||||
data::CudfAdapter adapter(json_str);
|
||||
*out =
|
||||
new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromArrayInterface(char const* c_json_strs,
|
||||
bst_float missing, int nthread, int max_bin,
|
||||
DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
std::string json_str{c_json_strs};
|
||||
data::CupyAdapter adapter(json_str);
|
||||
*out =
|
||||
new std::shared_ptr<DMatrix>(new data::DeviceDMatrix(&adapter, missing, nthread, max_bin));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@ -32,8 +32,8 @@ static const int kPadding = 4; // Assign padding so we can read slightly off
|
||||
// the beginning of the array
|
||||
|
||||
// The number of bits required to represent a given unsigned range
|
||||
static size_t SymbolBits(size_t num_symbols) {
|
||||
auto bits = std::ceil(std::log2(num_symbols));
|
||||
inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) {
|
||||
auto bits = std::ceil(log2(static_cast<double>(num_symbols)));
|
||||
return std::max(static_cast<size_t>(bits), size_t(1));
|
||||
}
|
||||
} // namespace detail
|
||||
@ -50,14 +50,11 @@ static size_t SymbolBits(size_t num_symbols) {
|
||||
*/
|
||||
|
||||
class CompressedBufferWriter {
|
||||
private:
|
||||
size_t symbol_bits_;
|
||||
size_t offset_;
|
||||
|
||||
public:
|
||||
explicit CompressedBufferWriter(size_t num_symbols) : offset_(0) {
|
||||
symbol_bits_ = detail::SymbolBits(num_symbols);
|
||||
}
|
||||
XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols)
|
||||
: symbol_bits_(detail::SymbolBits(num_symbols)) {}
|
||||
|
||||
/**
|
||||
* \fn static size_t CompressedBufferWriter::CalculateBufferSize(int
|
||||
@ -164,18 +161,15 @@ class CompressedBufferWriter {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
/**
|
||||
* \class CompressedIterator
|
||||
*
|
||||
* \brief Read symbols from a bit compressed memory buffer. Usable on device and
|
||||
* host.
|
||||
* \brief Read symbols from a bit compressed memory buffer. Usable on device and host.
|
||||
*
|
||||
* \author Rory
|
||||
* \date 7/9/2017
|
||||
*
|
||||
* \tparam T Generic type parameter.
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
class CompressedIterator {
|
||||
public:
|
||||
// Type definitions for thrust
|
||||
|
||||
@ -1540,4 +1540,12 @@ DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
|
||||
static_cast<typename OutputGradientT::ValueT>(gpair.GetHess()));
|
||||
}
|
||||
|
||||
|
||||
// Thrust version of this function causes error on Windows
|
||||
template <typename ReturnT, typename IterT, typename FuncT>
|
||||
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
|
||||
IterT iter, FuncT func) {
|
||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||
}
|
||||
|
||||
} // namespace dh
|
||||
|
||||
@ -338,31 +338,6 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
return cuts;
|
||||
}
|
||||
|
||||
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
||||
explicit IsValidFunctor(float missing) : missing(missing) {}
|
||||
|
||||
float missing;
|
||||
__device__ bool operator()(const data::COOTuple& e) const {
|
||||
if (common::CheckNAN(e.value) || e.value == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
__device__ bool operator()(const Entry& e) const {
|
||||
if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// Thrust version of this function causes error on Windows
|
||||
template <typename ReturnT, typename IterT, typename FuncT>
|
||||
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
|
||||
IterT iter, FuncT func) {
|
||||
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
|
||||
}
|
||||
|
||||
template <typename AdapterT>
|
||||
void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
SketchContainer* sketch_container, int num_cuts) {
|
||||
@ -372,10 +347,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
auto &batch = adapter->Value();
|
||||
// Enforce single batch
|
||||
CHECK(!adapter->Next());
|
||||
auto batch_iter = MakeTransformIterator<data::COOTuple>(
|
||||
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
||||
thrust::make_counting_iterator(0llu),
|
||||
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
||||
auto entry_iter = MakeTransformIterator<Entry>(
|
||||
auto entry_iter = dh::MakeTransformIterator<Entry>(
|
||||
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
|
||||
return Entry(batch.GetElement(idx).column_idx,
|
||||
batch.GetElement(idx).value);
|
||||
@ -385,7 +360,7 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
|
||||
0);
|
||||
|
||||
auto d_column_sizes_scan = column_sizes_scan.data().get();
|
||||
IsValidFunctor is_valid(missing);
|
||||
data::IsValidFunctor is_valid(missing);
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
auto e = batch_iter[begin + idx];
|
||||
if (is_valid(e)) {
|
||||
|
||||
@ -105,10 +105,10 @@ class HistogramCuts {
|
||||
auto end = cut_ptrs_.ConstHostVector().at(column_id + 1);
|
||||
const auto &values = cut_values_.ConstHostVector();
|
||||
auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
|
||||
if (it == values.cend()) {
|
||||
it = values.cend() - 1;
|
||||
}
|
||||
BinIdx idx = it - values.cbegin();
|
||||
if (idx == end) {
|
||||
idx -= 1;
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
#include "array_interface.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "device_adapter.cuh"
|
||||
#include "simple_dmatrix.h"
|
||||
#include "device_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
|
||||
@ -8,12 +8,31 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/math.h"
|
||||
#include "adapter.h"
|
||||
#include "array_interface.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
||||
explicit IsValidFunctor(float missing) : missing(missing) {}
|
||||
|
||||
float missing;
|
||||
__device__ bool operator()(const data::COOTuple& e) const {
|
||||
if (common::CheckNAN(e.value) || e.value == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
__device__ bool operator()(const Entry& e) const {
|
||||
if (common::CheckNAN(e.fvalue) || e.fvalue == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
CudfAdapterBatch() = default;
|
||||
|
||||
238
src/data/device_dmatrix.cu
Normal file
238
src/data/device_dmatrix.cu
Normal file
@ -0,0 +1,238 @@
|
||||
/*!
|
||||
* Copyright 2020 by Contributors
|
||||
* \file device_dmatrix.cu
|
||||
* \brief Device-memory version of DMatrix.
|
||||
*/
|
||||
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/iterator/transform_output_iterator.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "../common/hist_util.h"
|
||||
#include "adapter.h"
|
||||
#include "device_adapter.cuh"
|
||||
#include "ellpack_page.cuh"
|
||||
#include "device_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
// Returns maximum row length
|
||||
template <typename AdapterBatchT>
|
||||
size_t GetRowCounts(const AdapterBatchT& batch, common::Span<size_t> offset,
|
||||
int device_idx, float missing) {
|
||||
IsValidFunctor is_valid(missing);
|
||||
// Count elements per row
|
||||
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
|
||||
auto element = batch.GetElement(idx);
|
||||
if (is_valid(element)) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&offset[element.row_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
}
|
||||
});
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
size_t row_stride = thrust::reduce(
|
||||
thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data()) + offset.size(), size_t(0),
|
||||
thrust::maximum<size_t>());
|
||||
return row_stride;
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
struct WriteCompressedEllpackFunctor {
|
||||
WriteCompressedEllpackFunctor(common::CompressedByteT* buffer,
|
||||
const common::CompressedBufferWriter& writer,
|
||||
const AdapterBatchT& batch,
|
||||
EllpackDeviceAccessor accessor,
|
||||
const IsValidFunctor& is_valid)
|
||||
: d_buffer(buffer),
|
||||
writer(writer),
|
||||
batch(batch),
|
||||
accessor(std::move(accessor)),
|
||||
is_valid(is_valid) {}
|
||||
|
||||
common::CompressedByteT* d_buffer;
|
||||
common::CompressedBufferWriter writer;
|
||||
AdapterBatchT batch;
|
||||
EllpackDeviceAccessor accessor;
|
||||
IsValidFunctor is_valid;
|
||||
|
||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||
__device__ size_t operator()(Tuple out) {
|
||||
auto e = batch.GetElement(out.get<2>());
|
||||
if (is_valid(e)) {
|
||||
// -1 because the scan is inclusive
|
||||
size_t output_position =
|
||||
accessor.row_stride * e.row_idx + out.get<1>() - 1;
|
||||
auto bin_idx = accessor.SearchBin(e.value, e.column_idx);
|
||||
writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterBatchT>
|
||||
void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
|
||||
int device_idx, float missing) {
|
||||
// Some witchcraft happens here
|
||||
// The goal is to copy valid elements out of the input to an ellpack matrix
|
||||
// with a given row stride, using no extra working memory Standard stream
|
||||
// compaction needs to be modified to do this, so we manually define a
|
||||
// segmented stream compaction via operators on an inclusive scan. The output
|
||||
// of this inclusive scan is fed to a custom function which works out the
|
||||
// correct output position
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
IsValidFunctor is_valid(missing);
|
||||
auto key_iter = dh::MakeTransformIterator<size_t>(
|
||||
counting,
|
||||
[=] __device__(size_t idx) { return batch.GetElement(idx).row_idx; });
|
||||
auto value_iter = dh::MakeTransformIterator<size_t>(
|
||||
counting, [=] __device__(size_t idx) -> size_t {
|
||||
return is_valid(batch.GetElement(idx));
|
||||
});
|
||||
|
||||
auto key_value_index_iter = thrust::make_zip_iterator(
|
||||
thrust::make_tuple(key_iter, value_iter, counting));
|
||||
|
||||
// Tuple[0] = The row index of the input, used as a key to define segments
|
||||
// Tuple[1] = Scanned flags of valid elements for each row
|
||||
// Tuple[2] = The index in the input data
|
||||
using Tuple = thrust::tuple<size_t, size_t, size_t>;
|
||||
|
||||
auto device_accessor = dst->GetDeviceAccessor(device_idx);
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
|
||||
// We redirect the scan output into this functor to do the actual writing
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT> functor(
|
||||
d_compressed_buffer, writer, batch, device_accessor, is_valid);
|
||||
thrust::discard_iterator<size_t> discard;
|
||||
thrust::transform_output_iterator<
|
||||
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
|
||||
out(discard, functor);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
|
||||
key_value_index_iter + batch.Size(), out,
|
||||
[=] __device__(Tuple a, Tuple b) {
|
||||
// Key equal
|
||||
if (a.get<0>() == b.get<0>()) {
|
||||
b.get<1>() += a.get<1>();
|
||||
return b;
|
||||
}
|
||||
// Not equal
|
||||
return b;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename AdapterT, typename AdapterBatchT>
|
||||
void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
|
||||
EllpackPageImpl* dst, float missing) {
|
||||
// Step 1: Get the sizes of the input columns
|
||||
dh::caching_device_vector<size_t> column_sizes(adapter->NumColumns(), 0);
|
||||
auto d_column_sizes = column_sizes.data().get();
|
||||
// Populate column sizes
|
||||
dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&d_column_sizes[e.column_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
});
|
||||
|
||||
thrust::host_vector<size_t> host_column_sizes = column_sizes;
|
||||
|
||||
// Step 2: Iterate over columns, place elements in correct row, increment
|
||||
// temporary row pointers
|
||||
dh::caching_device_vector<size_t> temp_row_ptr(adapter->NumRows(), 0);
|
||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
||||
auto row_stride = dst->row_stride;
|
||||
size_t begin = 0;
|
||||
auto device_accessor = dst->GetDeviceAccessor(adapter->DeviceIdx());
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
IsValidFunctor is_valid(missing);
|
||||
for (auto size : host_column_sizes) {
|
||||
size_t end = begin + size;
|
||||
dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
|
||||
auto writer_non_const =
|
||||
writer; // For some reason this variable gets captured as const
|
||||
const auto& e = batch.GetElement(idx + begin);
|
||||
if (!is_valid(e)) return;
|
||||
size_t output_position =
|
||||
e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
|
||||
auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
|
||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
|
||||
output_position);
|
||||
d_temp_row_ptr[e.row_idx] += 1;
|
||||
});
|
||||
|
||||
begin = end;
|
||||
}
|
||||
}
|
||||
|
||||
void WriteNullValues(EllpackPageImpl* dst, int device_idx,
|
||||
common::Span<size_t> row_counts) {
|
||||
// Write the null values
|
||||
auto device_accessor = dst->GetDeviceAccessor(device_idx);
|
||||
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
|
||||
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
|
||||
auto row_stride = dst->row_stride;
|
||||
dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx) {
|
||||
auto writer_non_const =
|
||||
writer; // For some reason this variable gets captured as const
|
||||
size_t row_idx = idx / row_stride;
|
||||
size_t row_offset = idx % row_stride;
|
||||
if (row_offset >= row_counts[row_idx]) {
|
||||
writer_non_const.AtomicWriteSymbol(d_compressed_buffer,
|
||||
device_accessor.NullValue(), idx);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Does not currently support metainfo as no on-device data source contains this
|
||||
// Current implementation assumes a single batch. More batches can
|
||||
// be supported in future. Does not currently support inferring row/column size
|
||||
template <typename AdapterT>
|
||||
DeviceDMatrix::DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin) {
|
||||
common::HistogramCuts cuts =
|
||||
common::AdapterDeviceSketch(adapter, max_bin, missing);
|
||||
auto& batch = adapter->Value();
|
||||
// Work out how many valid entries we have in each row
|
||||
dh::caching_device_vector<size_t> row_counts(adapter->NumRows() + 1, 0);
|
||||
common::Span<size_t> row_counts_span(row_counts.data().get(),
|
||||
row_counts.size());
|
||||
size_t row_stride =
|
||||
GetRowCounts(batch, row_counts_span, adapter->DeviceIdx(), missing);
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
info.num_nonzero_ = thrust::reduce(thrust::cuda::par(alloc),
|
||||
row_counts.begin(), row_counts.end());
|
||||
info.num_col_ = adapter->NumColumns();
|
||||
info.num_row_ = adapter->NumRows();
|
||||
ellpack_page_.reset(new EllpackPage());
|
||||
*ellpack_page_->Impl() =
|
||||
EllpackPageImpl(adapter->DeviceIdx(), cuts, this->IsDense(), row_stride,
|
||||
adapter->NumRows());
|
||||
if (adapter->IsRowMajor()) {
|
||||
CopyDataRowMajor(batch, ellpack_page_->Impl(), adapter->DeviceIdx(),
|
||||
missing);
|
||||
} else {
|
||||
CopyDataColumnMajor(adapter, batch, ellpack_page_->Impl(), missing);
|
||||
}
|
||||
|
||||
WriteNullValues(ellpack_page_->Impl(), adapter->DeviceIdx(), row_counts_span);
|
||||
|
||||
// Synchronise worker columns
|
||||
rabit::Allreduce<rabit::op::Max>(&info.num_col_, 1);
|
||||
}
|
||||
template DeviceDMatrix::DeviceDMatrix(CudfAdapter* adapter, float missing,
|
||||
int nthread, int max_bin);
|
||||
template DeviceDMatrix::DeviceDMatrix(CupyAdapter* adapter, float missing,
|
||||
int nthread, int max_bin);
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
60
src/data/device_dmatrix.h
Normal file
60
src/data/device_dmatrix.h
Normal file
@ -0,0 +1,60 @@
|
||||
/*!
|
||||
* Copyright 2020 by Contributors
|
||||
* \file device_dmatrix.h
|
||||
* \brief Device-memory version of DMatrix.
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_DEVICE_DMATRIX_H_
|
||||
#define XGBOOST_DATA_DEVICE_DMATRIX_H_
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "adapter.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
#include "simple_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
class DeviceDMatrix : public DMatrix {
|
||||
public:
|
||||
template <typename AdapterT>
|
||||
explicit DeviceDMatrix(AdapterT* adapter, float missing, int nthread, int max_bin);
|
||||
|
||||
MetaInfo& Info() override { return info; }
|
||||
|
||||
const MetaInfo& Info() const override { return info; }
|
||||
|
||||
bool SingleColBlock() const override { return true; }
|
||||
|
||||
bool EllpackExists() const override { return true; }
|
||||
bool SparsePageExists() const override { return false; }
|
||||
|
||||
private:
|
||||
BatchSet<SparsePage> GetRowBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
|
||||
}
|
||||
BatchSet<CSCPage> GetColumnBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<CSCPage>(BatchIterator<CSCPage>(nullptr));
|
||||
}
|
||||
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<SortedCSCPage>(BatchIterator<SortedCSCPage>(nullptr));
|
||||
}
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override {
|
||||
auto begin_iter = BatchIterator<EllpackPage>(
|
||||
new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
|
||||
return BatchSet<EllpackPage>(begin_iter);
|
||||
}
|
||||
|
||||
MetaInfo info;
|
||||
// source data pointer.
|
||||
std::unique_ptr<EllpackPage> ellpack_page_;
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_DEVICE_DMATRIX_H_
|
||||
@ -26,7 +26,6 @@ void EllpackPage::SetBaseRowId(size_t row_id) {
|
||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||
"EllpackPage is required";
|
||||
}
|
||||
|
||||
size_t EllpackPage::Size() const {
|
||||
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but "
|
||||
"EllpackPage is required";
|
||||
|
||||
@ -210,8 +210,8 @@ void EllpackPageImpl::InitCompressedData(int device) {
|
||||
|
||||
// Required buffer size for storing data matrix in ELLPack format.
|
||||
size_t compressed_size_bytes =
|
||||
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
|
||||
num_symbols);
|
||||
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
|
||||
num_symbols);
|
||||
gidx_buffer.SetDevice(device);
|
||||
// Don't call fill unnecessarily
|
||||
if (gidx_buffer.Size() == 0) {
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
#include "../common/compressed_iterator.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include <thrust/binary_search.h>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -90,6 +91,19 @@ struct EllpackDeviceAccessor {
|
||||
}
|
||||
return gidx;
|
||||
}
|
||||
|
||||
__device__ uint32_t SearchBin(float value, size_t column_id) const {
|
||||
auto beg = feature_segments[column_id];
|
||||
auto end = feature_segments[column_id + 1];
|
||||
auto it =
|
||||
thrust::upper_bound(thrust::seq, gidx_fvalue_map.cbegin()+ beg, gidx_fvalue_map.cbegin() + end, value);
|
||||
uint32_t idx = it - gidx_fvalue_map.cbegin();
|
||||
if (idx == end) {
|
||||
idx -= 1;
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
__device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
|
||||
auto gidx = GetBinIndex(ridx, fidx);
|
||||
if (gidx == -1) {
|
||||
@ -104,7 +118,7 @@ struct EllpackDeviceAccessor {
|
||||
}
|
||||
/*! \brief Return the total number of symbols (total number of bins plus 1 for
|
||||
* not found). */
|
||||
size_t NumSymbols() const { return gidx_fvalue_map.size() + 1; }
|
||||
XGBOOST_DEVICE size_t NumSymbols() const { return gidx_fvalue_map.size() + 1; }
|
||||
|
||||
size_t NullValue() const { return gidx_fvalue_map.size(); }
|
||||
|
||||
|
||||
@ -8,26 +8,20 @@
|
||||
#include <xgboost/data.h>
|
||||
#include "../common/random.h"
|
||||
#include "./simple_dmatrix.h"
|
||||
#include "../common/math.h"
|
||||
#include "device_adapter.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
XGBOOST_DEVICE bool IsValid(float value, float missing) {
|
||||
if (common::CheckNAN(value) || value == missing) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
int device_idx, float missing) {
|
||||
IsValidFunctor is_valid(missing);
|
||||
// Count elements per row
|
||||
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
|
||||
auto element = batch.GetElement(idx);
|
||||
if (IsValid(element.value, missing)) {
|
||||
if (is_valid(element)) {
|
||||
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
|
||||
&offset[element.row_idx]),
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
@ -66,11 +60,12 @@ void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
thrust::device_pointer_cast(row_ptr.data() + row_ptr.size()));
|
||||
auto d_temp_row_ptr = temp_row_ptr.data().get();
|
||||
size_t begin = 0;
|
||||
IsValidFunctor is_valid(missing);
|
||||
for (auto size : host_column_sizes) {
|
||||
size_t end = begin + size;
|
||||
dh::LaunchN(device_idx, end - begin, [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx + begin);
|
||||
if (!IsValid(e.value, missing)) return;
|
||||
if (!is_valid(e)) return;
|
||||
data[d_temp_row_ptr[e.row_idx]] = Entry(e.column_idx, e.value);
|
||||
d_temp_row_ptr[e.row_idx] += 1;
|
||||
});
|
||||
@ -79,15 +74,6 @@ void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
}
|
||||
}
|
||||
|
||||
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
||||
explicit IsValidFunctor(float missing) : missing(missing) {}
|
||||
|
||||
float missing;
|
||||
__device__ bool operator()(const Entry& x) const {
|
||||
return IsValid(x.fvalue, missing);
|
||||
}
|
||||
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterT>
|
||||
|
||||
@ -400,10 +400,11 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
|
||||
|
||||
auto on_device =
|
||||
f_dmat &&
|
||||
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
|
||||
(f_dmat->PageExists<EllpackPage>() ||
|
||||
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead());
|
||||
|
||||
// Use GPU Predictor if data is already on device.
|
||||
if (on_device) {
|
||||
// Use GPU Predictor if data is already on device and gpu_id is set.
|
||||
if (on_device && generic_param_->gpu_id >= 0) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
CHECK(gpu_predictor_);
|
||||
return gpu_predictor_;
|
||||
|
||||
@ -44,7 +44,7 @@ case "$suite" in
|
||||
cudf)
|
||||
source activate cudf_test
|
||||
install_xgboost
|
||||
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py tests/python-gpu/test_from_cupy.py
|
||||
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_cudf.py tests/python-gpu/test_from_cupy.py
|
||||
;;
|
||||
|
||||
cpu)
|
||||
|
||||
@ -284,5 +284,28 @@ TEST(hist_util, AdapterDeviceSketchBatches) {
|
||||
ValidateCuts(cuts, dmat.get(), num_bins);
|
||||
}
|
||||
}
|
||||
|
||||
// Check sketching from adapter or DMatrix results in the same answer
|
||||
// Consistency here is useful for testing and user experience
|
||||
TEST(hist_util, SketchingEquivalent) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = GenerateRandom(num_rows, num_columns);
|
||||
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto dmat_cuts = DeviceSketch(0, dmat.get(), num_bins);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = AdapterFromData(x_device, num_rows, num_columns);
|
||||
auto adapter_cuts = AdapterDeviceSketch(
|
||||
&adapter, num_bins, std::numeric_limits<float>::quiet_NaN());
|
||||
EXPECT_EQ(dmat_cuts.Values(), adapter_cuts.Values());
|
||||
EXPECT_EQ(dmat_cuts.Ptrs(), adapter_cuts.Ptrs());
|
||||
EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
131
tests/cpp/data/test_device_dmatrix.cu
Normal file
131
tests/cpp/data/test_device_dmatrix.cu
Normal file
@ -0,0 +1,131 @@
|
||||
|
||||
// Copyright (c) 2019 by Contributors
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
#include "../../../src/data/adapter.h"
|
||||
#include "../../../src/data/ellpack_page.cuh"
|
||||
#include "../../../src/data/device_dmatrix.h"
|
||||
#include "../helpers.h"
|
||||
#include <thrust/device_vector.h>
|
||||
#include "../../../src/data/device_adapter.cuh"
|
||||
#include "../../../src/gbm/gbtree_model.h"
|
||||
#include "../common/test_hist_util.h"
|
||||
#include "../../../src/common/compressed_iterator.h"
|
||||
#include "../../../src/common/math.h"
|
||||
#include "test_array_interface.h"
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
TEST(DeviceDMatrix, RowMajor) {
|
||||
int num_rows = 1000;
|
||||
int num_columns = 50;
|
||||
auto x = common::GenerateRandom(num_rows, num_columns);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);
|
||||
|
||||
data::DeviceDMatrix dmat(&adapter,
|
||||
std::numeric_limits<float>::quiet_NaN(), 1, 256);
|
||||
|
||||
auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
|
||||
auto impl = batch.Impl();
|
||||
common::CompressedIterator<uint32_t> iterator(
|
||||
impl->gidx_buffer.HostVector().data(), impl->NumSymbols());
|
||||
for(auto i = 0ull; i < x.size(); i++)
|
||||
{
|
||||
int column_idx = i % num_columns;
|
||||
EXPECT_EQ(impl->cuts_.SearchBin(x[i], column_idx), iterator[i]);
|
||||
}
|
||||
EXPECT_EQ(dmat.Info().num_col_, num_columns);
|
||||
EXPECT_EQ(dmat.Info().num_row_, num_rows);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, num_rows * num_columns);
|
||||
|
||||
}
|
||||
|
||||
TEST(DeviceDMatrix, RowMajorMissing) {
|
||||
const float kMissing = std::numeric_limits<float>::quiet_NaN();
|
||||
int num_rows = 10;
|
||||
int num_columns = 2;
|
||||
auto x = common::GenerateRandom(num_rows, num_columns);
|
||||
x[1] = kMissing;
|
||||
x[5] = kMissing;
|
||||
x[6] = kMissing;
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);
|
||||
|
||||
data::DeviceDMatrix dmat(&adapter, kMissing, 1, 256);
|
||||
|
||||
auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
|
||||
auto impl = batch.Impl();
|
||||
common::CompressedIterator<uint32_t> iterator(
|
||||
impl->gidx_buffer.HostVector().data(), impl->NumSymbols());
|
||||
EXPECT_EQ(iterator[1], impl->GetDeviceAccessor(0).NullValue());
|
||||
EXPECT_EQ(iterator[5], impl->GetDeviceAccessor(0).NullValue());
|
||||
// null values get placed after valid values in a row
|
||||
EXPECT_EQ(iterator[7], impl->GetDeviceAccessor(0).NullValue());
|
||||
EXPECT_EQ(dmat.Info().num_col_, num_columns);
|
||||
EXPECT_EQ(dmat.Info().num_row_, num_rows);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, num_rows*num_columns-3);
|
||||
|
||||
}
|
||||
|
||||
TEST(DeviceDMatrix, ColumnMajor) {
|
||||
constexpr size_t kRows{100};
|
||||
std::vector<Json> columns;
|
||||
thrust::device_vector<double> d_data_0(kRows);
|
||||
thrust::device_vector<uint32_t> d_data_1(kRows);
|
||||
|
||||
columns.emplace_back(GenerateDenseColumn<double>("<f8", kRows, &d_data_0));
|
||||
columns.emplace_back(GenerateDenseColumn<uint32_t>("<u4", kRows, &d_data_1));
|
||||
|
||||
Json column_arr{columns};
|
||||
|
||||
std::stringstream ss;
|
||||
Json::Dump(column_arr, &ss);
|
||||
std::string str = ss.str();
|
||||
|
||||
data::CudfAdapter adapter(str);
|
||||
data::DeviceDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
|
||||
-1, 256);
|
||||
auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
|
||||
auto impl = batch.Impl();
|
||||
common::CompressedIterator<uint32_t> iterator(
|
||||
impl->gidx_buffer.HostVector().data(), impl->NumSymbols());
|
||||
|
||||
for (auto i = 0ull; i < kRows; i++) {
|
||||
for (auto j = 0ull; j < columns.size(); j++) {
|
||||
if (j == 0) {
|
||||
EXPECT_EQ(iterator[i * 2 + j], impl->cuts_.SearchBin(d_data_0[i], j));
|
||||
} else {
|
||||
EXPECT_EQ(iterator[i * 2 + j], impl->cuts_.SearchBin(d_data_1[i], j));
|
||||
}
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(dmat.Info().num_col_, 2);
|
||||
EXPECT_EQ(dmat.Info().num_row_, kRows);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, kRows*2);
|
||||
|
||||
}
|
||||
|
||||
// Test equivalence with simple DMatrix
|
||||
TEST(DeviceDMatrix, Equivalent) {
|
||||
int bin_sizes[] = {2, 16, 256, 512};
|
||||
int sizes[] = {100, 1000, 1500};
|
||||
int num_columns = 5;
|
||||
for (auto num_rows : sizes) {
|
||||
auto x = common::GenerateRandom(num_rows, num_columns);
|
||||
for (auto num_bins : bin_sizes) {
|
||||
auto dmat = common::GetDMatrixFromData(x, num_rows, num_columns);
|
||||
auto x_device = thrust::device_vector<float>(x);
|
||||
auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);
|
||||
data::DeviceDMatrix device_dmat(
|
||||
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, num_bins);
|
||||
|
||||
const auto &batch = *dmat->GetBatches<EllpackPage>({0, num_bins}).begin();
|
||||
const auto &device_dmat_batch =
|
||||
*device_dmat.GetBatches<EllpackPage>({0, num_bins}).begin();
|
||||
|
||||
ASSERT_EQ(batch.Impl()->cuts_.Values(), device_dmat_batch.Impl()->cuts_.Values());
|
||||
ASSERT_EQ(batch.Impl()->gidx_buffer.HostVector(),
|
||||
device_dmat_batch.Impl()->gidx_buffer.HostVector());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,148 +0,0 @@
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
|
||||
|
||||
def dmatrix_from_cudf(input_type, missing=np.NAN):
|
||||
'''Test constructing DMatrix from cudf'''
|
||||
import cudf
|
||||
import pandas as pd
|
||||
|
||||
kRows = 80
|
||||
kCols = 3
|
||||
|
||||
na = np.random.randn(kRows, kCols)
|
||||
na[:, 0:2] = na[:, 0:2].astype(input_type)
|
||||
|
||||
na[5, 0] = missing
|
||||
na[3, 1] = missing
|
||||
|
||||
pa = pd.DataFrame({'0': na[:, 0],
|
||||
'1': na[:, 1],
|
||||
'2': na[:, 2].astype(np.int32)})
|
||||
|
||||
np_label = np.random.randn(kRows).astype(input_type)
|
||||
pa_label = pd.DataFrame(np_label)
|
||||
|
||||
cd = cudf.from_pandas(pa)
|
||||
cd_label = cudf.from_pandas(pa_label).iloc[:, 0]
|
||||
|
||||
dtrain = xgb.DMatrix(cd, missing=missing, label=cd_label)
|
||||
assert dtrain.num_col() == kCols
|
||||
assert dtrain.num_row() == kRows
|
||||
|
||||
|
||||
class TestFromColumnar:
|
||||
'''Tests for constructing DMatrix from data structure conforming Apache
|
||||
Arrow specification.'''
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_from_cudf(self):
|
||||
'''Test constructing DMatrix from cudf'''
|
||||
import cudf
|
||||
dmatrix_from_cudf(np.float32, np.NAN)
|
||||
dmatrix_from_cudf(np.float64, np.NAN)
|
||||
|
||||
dmatrix_from_cudf(np.int8, 2)
|
||||
dmatrix_from_cudf(np.int32, -2)
|
||||
dmatrix_from_cudf(np.int64, -3)
|
||||
|
||||
cd = cudf.DataFrame({'x': [1, 2, 3], 'y': [0.1, 0.2, 0.3]})
|
||||
dtrain = xgb.DMatrix(cd)
|
||||
|
||||
assert dtrain.feature_names == ['x', 'y']
|
||||
assert dtrain.feature_types == ['int', 'float']
|
||||
|
||||
series = cudf.DataFrame({'x': [1, 2, 3]}).iloc[:, 0]
|
||||
assert isinstance(series, cudf.Series)
|
||||
dtrain = xgb.DMatrix(series)
|
||||
|
||||
assert dtrain.feature_names == ['x']
|
||||
assert dtrain.feature_types == ['int']
|
||||
|
||||
with pytest.raises(Exception):
|
||||
dtrain = xgb.DMatrix(cd, label=cd)
|
||||
|
||||
# Test when number of elements is less than 8
|
||||
X = cudf.DataFrame({'x': cudf.Series([0, 1, 2, np.NAN, 4],
|
||||
dtype=np.int32)})
|
||||
dtrain = xgb.DMatrix(X)
|
||||
assert dtrain.num_col() == 1
|
||||
assert dtrain.num_row() == 5
|
||||
|
||||
# Boolean is not supported.
|
||||
X_boolean = cudf.DataFrame({'x': cudf.Series([True, False])})
|
||||
with pytest.raises(Exception):
|
||||
dtrain = xgb.DMatrix(X_boolean)
|
||||
|
||||
y_boolean = cudf.DataFrame({
|
||||
'x': cudf.Series([True, False, True, True, True])})
|
||||
with pytest.raises(Exception):
|
||||
dtrain = xgb.DMatrix(X_boolean, label=y_boolean)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_cudf_training(self):
|
||||
from cudf import DataFrame as df
|
||||
import pandas as pd
|
||||
np.random.seed(1)
|
||||
X = pd.DataFrame(np.random.randn(50, 10))
|
||||
y = pd.DataFrame(np.random.randn(50))
|
||||
weights = np.random.random(50) + 1.0
|
||||
cudf_weights = df.from_pandas(pd.DataFrame(weights))
|
||||
base_margin = np.random.random(50)
|
||||
cudf_base_margin = df.from_pandas(pd.DataFrame(base_margin))
|
||||
|
||||
evals_result_cudf = {}
|
||||
dtrain_cudf = xgb.DMatrix(df.from_pandas(X), df.from_pandas(y), weight=cudf_weights,
|
||||
base_margin=cudf_base_margin)
|
||||
params = {'gpu_id': 0}
|
||||
xgb.train(params, dtrain_cudf, evals=[(dtrain_cudf, "train")],
|
||||
evals_result=evals_result_cudf)
|
||||
evals_result_np = {}
|
||||
dtrain_np = xgb.DMatrix(X, y, weight=weights, base_margin=base_margin)
|
||||
xgb.train(params, dtrain_np, evals=[(dtrain_np, "train")],
|
||||
evals_result=evals_result_np)
|
||||
assert np.array_equal(evals_result_cudf["train"]["rmse"], evals_result_np["train"]["rmse"])
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_cudf_metainfo(self):
|
||||
from cudf import DataFrame as df
|
||||
import pandas as pd
|
||||
n = 100
|
||||
X = np.random.random((n, 2))
|
||||
dmat_cudf = xgb.DMatrix(X)
|
||||
dmat = xgb.DMatrix(X)
|
||||
floats = np.random.random(n)
|
||||
uints = np.array([4, 2, 8]).astype("uint32")
|
||||
cudf_floats = df.from_pandas(pd.DataFrame(floats))
|
||||
cudf_uints = df.from_pandas(pd.DataFrame(uints))
|
||||
dmat.set_float_info('weight', floats)
|
||||
dmat.set_float_info('label', floats)
|
||||
dmat.set_float_info('base_margin', floats)
|
||||
dmat.set_uint_info('group', uints)
|
||||
dmat_cudf.set_interface_info('weight', cudf_floats)
|
||||
dmat_cudf.set_interface_info('label', cudf_floats)
|
||||
dmat_cudf.set_interface_info('base_margin', cudf_floats)
|
||||
dmat_cudf.set_interface_info('group', cudf_uints)
|
||||
|
||||
# Test setting info with cudf DataFrame
|
||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||
dmat_cudf.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
||||
|
||||
# Test setting info with cudf Series
|
||||
dmat_cudf.set_interface_info('weight', cudf_floats[cudf_floats.columns[0]])
|
||||
dmat_cudf.set_interface_info('label', cudf_floats[cudf_floats.columns[0]])
|
||||
dmat_cudf.set_interface_info('base_margin', cudf_floats[cudf_floats.columns[0]])
|
||||
dmat_cudf.set_interface_info('group', cudf_uints[cudf_uints.columns[0]])
|
||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||
dmat_cudf.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
||||
172
tests/python-gpu/test_from_cudf.py
Normal file
172
tests/python-gpu/test_from_cudf.py
Normal file
@ -0,0 +1,172 @@
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
|
||||
|
||||
def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN):
|
||||
'''Test constructing DMatrix from cudf'''
|
||||
import cudf
|
||||
import pandas as pd
|
||||
|
||||
kRows = 80
|
||||
kCols = 3
|
||||
|
||||
na = np.random.randn(kRows, kCols)
|
||||
na[:, 0:2] = na[:, 0:2].astype(input_type)
|
||||
|
||||
na[5, 0] = missing
|
||||
na[3, 1] = missing
|
||||
|
||||
pa = pd.DataFrame({'0': na[:, 0],
|
||||
'1': na[:, 1],
|
||||
'2': na[:, 2].astype(np.int32)})
|
||||
|
||||
np_label = np.random.randn(kRows).astype(input_type)
|
||||
pa_label = pd.DataFrame(np_label)
|
||||
|
||||
cd = cudf.from_pandas(pa)
|
||||
cd_label = cudf.from_pandas(pa_label).iloc[:, 0]
|
||||
|
||||
dtrain = DMatrixT(cd, missing=missing, label=cd_label)
|
||||
assert dtrain.num_col() == kCols
|
||||
assert dtrain.num_row() == kRows
|
||||
|
||||
|
||||
def _test_from_cudf(DMatrixT):
|
||||
'''Test constructing DMatrix from cudf'''
|
||||
import cudf
|
||||
dmatrix_from_cudf(np.float32, DMatrixT, np.NAN)
|
||||
dmatrix_from_cudf(np.float64, DMatrixT, np.NAN)
|
||||
|
||||
dmatrix_from_cudf(np.int8, DMatrixT, 2)
|
||||
dmatrix_from_cudf(np.int32, DMatrixT, -2)
|
||||
dmatrix_from_cudf(np.int64, DMatrixT, -3)
|
||||
|
||||
cd = cudf.DataFrame({'x': [1, 2, 3], 'y': [0.1, 0.2, 0.3]})
|
||||
dtrain = DMatrixT(cd)
|
||||
|
||||
assert dtrain.feature_names == ['x', 'y']
|
||||
assert dtrain.feature_types == ['int', 'float']
|
||||
|
||||
series = cudf.DataFrame({'x': [1, 2, 3]}).iloc[:, 0]
|
||||
assert isinstance(series, cudf.Series)
|
||||
dtrain = DMatrixT(series)
|
||||
|
||||
assert dtrain.feature_names == ['x']
|
||||
assert dtrain.feature_types == ['int']
|
||||
|
||||
with pytest.raises(Exception):
|
||||
dtrain = DMatrixT(cd, label=cd)
|
||||
|
||||
# Test when number of elements is less than 8
|
||||
X = cudf.DataFrame({'x': cudf.Series([0, 1, 2, np.NAN, 4],
|
||||
dtype=np.int32)})
|
||||
dtrain = DMatrixT(X)
|
||||
assert dtrain.num_col() == 1
|
||||
assert dtrain.num_row() == 5
|
||||
|
||||
# Boolean is not supported.
|
||||
X_boolean = cudf.DataFrame({'x': cudf.Series([True, False])})
|
||||
with pytest.raises(Exception):
|
||||
dtrain = DMatrixT(X_boolean)
|
||||
|
||||
y_boolean = cudf.DataFrame({
|
||||
'x': cudf.Series([True, False, True, True, True])})
|
||||
with pytest.raises(Exception):
|
||||
dtrain = DMatrixT(X_boolean, label=y_boolean)
|
||||
|
||||
|
||||
def _test_cudf_training(DMatrixT):
|
||||
from cudf import DataFrame as df
|
||||
import pandas as pd
|
||||
np.random.seed(1)
|
||||
X = pd.DataFrame(np.random.randn(50, 10))
|
||||
y = pd.DataFrame(np.random.randn(50))
|
||||
weights = np.random.random(50) + 1.0
|
||||
cudf_weights = df.from_pandas(pd.DataFrame(weights))
|
||||
base_margin = np.random.random(50)
|
||||
cudf_base_margin = df.from_pandas(pd.DataFrame(base_margin))
|
||||
|
||||
evals_result_cudf = {}
|
||||
dtrain_cudf = DMatrixT(df.from_pandas(X), df.from_pandas(y), weight=cudf_weights,
|
||||
base_margin=cudf_base_margin)
|
||||
params = {'gpu_id': 0, 'tree_method': 'gpu_hist'}
|
||||
xgb.train(params, dtrain_cudf, evals=[(dtrain_cudf, "train")],
|
||||
evals_result=evals_result_cudf)
|
||||
evals_result_np = {}
|
||||
dtrain_np = xgb.DMatrix(X, y, weight=weights, base_margin=base_margin)
|
||||
xgb.train(params, dtrain_np, evals=[(dtrain_np, "train")],
|
||||
evals_result=evals_result_np)
|
||||
assert np.array_equal(evals_result_cudf["train"]["rmse"], evals_result_np["train"]["rmse"])
|
||||
|
||||
|
||||
def _test_cudf_metainfo(DMatrixT):
|
||||
from cudf import DataFrame as df
|
||||
import pandas as pd
|
||||
n = 100
|
||||
X = np.random.random((n, 2))
|
||||
dmat_cudf = DMatrixT(df.from_pandas(pd.DataFrame(X)))
|
||||
dmat = xgb.DMatrix(X)
|
||||
floats = np.random.random(n)
|
||||
uints = np.array([4, 2, 8]).astype("uint32")
|
||||
cudf_floats = df.from_pandas(pd.DataFrame(floats))
|
||||
cudf_uints = df.from_pandas(pd.DataFrame(uints))
|
||||
dmat.set_float_info('weight', floats)
|
||||
dmat.set_float_info('label', floats)
|
||||
dmat.set_float_info('base_margin', floats)
|
||||
dmat.set_uint_info('group', uints)
|
||||
dmat_cudf.set_interface_info('weight', cudf_floats)
|
||||
dmat_cudf.set_interface_info('label', cudf_floats)
|
||||
dmat_cudf.set_interface_info('base_margin', cudf_floats)
|
||||
dmat_cudf.set_interface_info('group', cudf_uints)
|
||||
|
||||
# Test setting info with cudf DataFrame
|
||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||
dmat_cudf.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
||||
|
||||
# Test setting info with cudf Series
|
||||
dmat_cudf.set_interface_info('weight', cudf_floats[cudf_floats.columns[0]])
|
||||
dmat_cudf.set_interface_info('label', cudf_floats[cudf_floats.columns[0]])
|
||||
dmat_cudf.set_interface_info('base_margin', cudf_floats[cudf_floats.columns[0]])
|
||||
dmat_cudf.set_interface_info('group', cudf_uints[cudf_uints.columns[0]])
|
||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||
dmat_cudf.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
||||
|
||||
|
||||
class TestFromColumnar:
|
||||
'''Tests for constructing DMatrix from data structure conforming Apache
|
||||
Arrow specification.'''
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_simple_dmatrix_from_cudf(self):
|
||||
_test_from_cudf(xgb.DMatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_device_dmatrix_from_cudf(self):
|
||||
_test_from_cudf(xgb.DeviceQuantileDMatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_cudf_training_simple_dmatrix(self):
|
||||
_test_cudf_training(xgb.DMatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_cudf_training_device_dmatrix(self):
|
||||
_test_cudf_training(xgb.DeviceQuantileDMatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_cudf_metainfo_simple_dmatrix(self):
|
||||
_test_cudf_metainfo(xgb.DMatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_cudf_metainfo_device_dmatrix(self):
|
||||
_test_cudf_metainfo(xgb.DeviceQuantileDMatrix)
|
||||
@ -7,7 +7,7 @@ sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
|
||||
|
||||
def dmatrix_from_cupy(input_type, missing=np.NAN):
|
||||
def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
||||
'''Test constructing DMatrix from cupy'''
|
||||
import cupy as cp
|
||||
|
||||
@ -19,82 +19,106 @@ def dmatrix_from_cupy(input_type, missing=np.NAN):
|
||||
X[5, 0] = missing
|
||||
X[3, 1] = missing
|
||||
y = cp.random.randn(kRows).astype(dtype=input_type)
|
||||
dtrain = xgb.DMatrix(X, missing=missing, label=y)
|
||||
dtrain = DMatrixT(X, missing=missing, label=y)
|
||||
assert dtrain.num_col() == kCols
|
||||
assert dtrain.num_row() == kRows
|
||||
return dtrain
|
||||
|
||||
|
||||
def _test_from_cupy(DMatrixT):
|
||||
'''Test constructing DMatrix from cupy'''
|
||||
import cupy as cp
|
||||
dmatrix_from_cupy(np.float32, DMatrixT, np.NAN)
|
||||
dmatrix_from_cupy(np.float64, DMatrixT, np.NAN)
|
||||
|
||||
dmatrix_from_cupy(np.uint8, DMatrixT, 2)
|
||||
dmatrix_from_cupy(np.uint32, DMatrixT, 3)
|
||||
dmatrix_from_cupy(np.uint64, DMatrixT, 4)
|
||||
|
||||
dmatrix_from_cupy(np.int8, DMatrixT, 2)
|
||||
dmatrix_from_cupy(np.int32, DMatrixT, -2)
|
||||
dmatrix_from_cupy(np.int64, DMatrixT, -3)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
X = cp.random.randn(2, 2, dtype="float32")
|
||||
dtrain = DMatrixT(X, label=X)
|
||||
|
||||
|
||||
def _test_cupy_training(DMatrixT):
|
||||
import cupy as cp
|
||||
np.random.seed(1)
|
||||
cp.random.seed(1)
|
||||
X = cp.random.randn(50, 10, dtype="float32")
|
||||
y = cp.random.randn(50, dtype="float32")
|
||||
weights = np.random.random(50) + 1
|
||||
cupy_weights = cp.array(weights)
|
||||
base_margin = np.random.random(50)
|
||||
cupy_base_margin = cp.array(base_margin)
|
||||
|
||||
evals_result_cupy = {}
|
||||
dtrain_cp = DMatrixT(X, y, weight=cupy_weights, base_margin=cupy_base_margin)
|
||||
params = {'gpu_id': 0, 'nthread': 1, 'tree_method': 'gpu_hist'}
|
||||
xgb.train(params, dtrain_cp, evals=[(dtrain_cp, "train")],
|
||||
evals_result=evals_result_cupy)
|
||||
evals_result_np = {}
|
||||
dtrain_np = xgb.DMatrix(cp.asnumpy(X), cp.asnumpy(y), weight=weights,
|
||||
base_margin=base_margin)
|
||||
xgb.train(params, dtrain_np, evals=[(dtrain_np, "train")],
|
||||
evals_result=evals_result_np)
|
||||
assert np.array_equal(evals_result_cupy["train"]["rmse"], evals_result_np["train"]["rmse"])
|
||||
|
||||
|
||||
def _test_cupy_metainfo(DMatrixT):
|
||||
import cupy as cp
|
||||
n = 100
|
||||
X = np.random.random((n, 2))
|
||||
dmat_cupy = DMatrixT(cp.array(X))
|
||||
dmat = xgb.DMatrix(X)
|
||||
floats = np.random.random(n)
|
||||
uints = np.array([4, 2, 8]).astype("uint32")
|
||||
cupy_floats = cp.array(floats)
|
||||
cupy_uints = cp.array(uints)
|
||||
dmat.set_float_info('weight', floats)
|
||||
dmat.set_float_info('label', floats)
|
||||
dmat.set_float_info('base_margin', floats)
|
||||
dmat.set_uint_info('group', uints)
|
||||
dmat_cupy.set_interface_info('weight', cupy_floats)
|
||||
dmat_cupy.set_interface_info('label', cupy_floats)
|
||||
dmat_cupy.set_interface_info('base_margin', cupy_floats)
|
||||
dmat_cupy.set_interface_info('group', cupy_uints)
|
||||
|
||||
# Test setting info with cupy
|
||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cupy.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cupy.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||
dmat_cupy.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
|
||||
|
||||
|
||||
class TestFromArrayInterface:
|
||||
'''Tests for constructing DMatrix from data structure conforming Apache
|
||||
Arrow specification.'''
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_from_cupy(self):
|
||||
'''Test constructing DMatrix from cupy'''
|
||||
import cupy as cp
|
||||
dmatrix_from_cupy(np.float32, np.NAN)
|
||||
dmatrix_from_cupy(np.float64, np.NAN)
|
||||
|
||||
dmatrix_from_cupy(np.uint8, 2)
|
||||
dmatrix_from_cupy(np.uint32, 3)
|
||||
dmatrix_from_cupy(np.uint64, 4)
|
||||
|
||||
dmatrix_from_cupy(np.int8, 2)
|
||||
dmatrix_from_cupy(np.int32, -2)
|
||||
dmatrix_from_cupy(np.int64, -3)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
X = cp.random.randn(2, 2, dtype="float32")
|
||||
dtrain = xgb.DMatrix(X, label=X)
|
||||
def test_simple_dmat_from_cupy(self):
|
||||
_test_from_cupy(xgb.DMatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_cupy_training(self):
|
||||
import cupy as cp
|
||||
np.random.seed(1)
|
||||
cp.random.seed(1)
|
||||
X = cp.random.randn(50, 10, dtype="float32")
|
||||
y = cp.random.randn(50, dtype="float32")
|
||||
weights = np.random.random(50) + 1
|
||||
cupy_weights = cp.array(weights)
|
||||
base_margin = np.random.random(50)
|
||||
cupy_base_margin = cp.array(base_margin)
|
||||
|
||||
evals_result_cupy = {}
|
||||
dtrain_cp = xgb.DMatrix(X, y, weight=cupy_weights, base_margin=cupy_base_margin)
|
||||
params = {'gpu_id': 0, 'nthread': 1}
|
||||
xgb.train(params, dtrain_cp, evals=[(dtrain_cp, "train")],
|
||||
evals_result=evals_result_cupy)
|
||||
evals_result_np = {}
|
||||
dtrain_np = xgb.DMatrix(cp.asnumpy(X), cp.asnumpy(y), weight=weights,
|
||||
base_margin=base_margin)
|
||||
xgb.train(params, dtrain_np, evals=[(dtrain_np, "train")],
|
||||
evals_result=evals_result_np)
|
||||
assert np.array_equal(evals_result_cupy["train"]["rmse"], evals_result_np["train"]["rmse"])
|
||||
def test_device_dmat_from_cupy(self):
|
||||
_test_from_cupy(xgb.DeviceQuantileDMatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_cupy_metainfo(self):
|
||||
import cupy as cp
|
||||
n = 100
|
||||
X = np.random.random((n, 2))
|
||||
dmat_cupy = xgb.DMatrix(X)
|
||||
dmat = xgb.DMatrix(X)
|
||||
floats = np.random.random(n)
|
||||
uints = np.array([4, 2, 8]).astype("uint32")
|
||||
cupy_floats = cp.array(floats)
|
||||
cupy_uints = cp.array(uints)
|
||||
dmat.set_float_info('weight', floats)
|
||||
dmat.set_float_info('label', floats)
|
||||
dmat.set_float_info('base_margin', floats)
|
||||
dmat.set_uint_info('group', uints)
|
||||
dmat_cupy.set_interface_info('weight', cupy_floats)
|
||||
dmat_cupy.set_interface_info('label', cupy_floats)
|
||||
dmat_cupy.set_interface_info('base_margin', cupy_floats)
|
||||
dmat_cupy.set_interface_info('group', cupy_uints)
|
||||
def test_cupy_training_device_dmat(self):
|
||||
_test_cupy_training(xgb.DeviceQuantileDMatrix)
|
||||
|
||||
# Test setting info with cupy
|
||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cupy.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cupy.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||
dmat_cupy.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_cupy_training_simple_dmat(self):
|
||||
_test_cupy_training(xgb.DMatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_cupy_metainfo_simple_dmat(self):
|
||||
_test_cupy_metainfo(xgb.DMatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_cupy_metainfo_device_dmat(self):
|
||||
_test_cupy_metainfo(xgb.DeviceQuantileDMatrix)
|
||||
|
||||
@ -2,9 +2,10 @@ import numpy as np
|
||||
import sys
|
||||
import unittest
|
||||
import pytest
|
||||
import xgboost
|
||||
import xgboost as xgb
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
from regression_test_utilities import run_suite, parameter_combinations, \
|
||||
assert_results_non_increasing
|
||||
|
||||
@ -40,6 +41,19 @@ class TestGPU(unittest.TestCase):
|
||||
cpu_results = run_suite(param, select_datasets=datasets)
|
||||
assert_gpu_results(cpu_results, gpu_results)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_gpu_hist_device_dmatrix(self):
|
||||
# DeviceDMatrix does not currently accept sparse formats
|
||||
device_dmatrix_datasets = ["Boston", "Cancer", "Digits"]
|
||||
for param in test_param:
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
gpu_results_device_dmatrix = run_suite(param, select_datasets=device_dmatrix_datasets,
|
||||
DMatrixT=xgb.DeviceQuantileDMatrix,
|
||||
dmatrix_params={'max_bin': param['max_bin']})
|
||||
assert_results_non_increasing(gpu_results_device_dmatrix, 1e-2)
|
||||
gpu_results = run_suite(param, select_datasets=device_dmatrix_datasets)
|
||||
assert_gpu_results(gpu_results, gpu_results_device_dmatrix)
|
||||
|
||||
# NOTE(rongou): Because the `Boston` dataset is too small, this only tests external memory mode
|
||||
# with a single page. To test multiple pages, set DMatrix::kPageSize to, say, 1024.
|
||||
def test_external_memory(self):
|
||||
@ -61,20 +75,20 @@ class TestGPU(unittest.TestCase):
|
||||
X = np.empty((kRows, kCols))
|
||||
y = np.empty((kRows))
|
||||
|
||||
dtrain = xgboost.DMatrix(X, y)
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
|
||||
bst = xgboost.train({'verbosity': 2,
|
||||
'tree_method': 'gpu_hist',
|
||||
'gpu_id': 0},
|
||||
dtrain,
|
||||
verbose_eval=True,
|
||||
num_boost_round=6,
|
||||
evals=[(dtrain, 'Train')])
|
||||
bst = xgb.train({'verbosity': 2,
|
||||
'tree_method': 'gpu_hist',
|
||||
'gpu_id': 0},
|
||||
dtrain,
|
||||
verbose_eval=True,
|
||||
num_boost_round=6,
|
||||
evals=[(dtrain, 'Train')])
|
||||
|
||||
kRows = 100
|
||||
X = np.random.randn(kRows, kCols)
|
||||
|
||||
dtest = xgboost.DMatrix(X)
|
||||
dtest = xgb.DMatrix(X)
|
||||
predictions = bst.predict(dtest)
|
||||
np.testing.assert_allclose(predictions, 0.5, 1e-6)
|
||||
|
||||
|
||||
@ -84,7 +84,8 @@ def get_weights_regression(min_weight, max_weight):
|
||||
return X, y, w
|
||||
|
||||
|
||||
def train_dataset(dataset, param_in, num_rounds=10, scale_features=False):
|
||||
def train_dataset(dataset, param_in, num_rounds=10, scale_features=False, DMatrixT=xgb.DMatrix,
|
||||
dmatrix_params={}):
|
||||
param = param_in.copy()
|
||||
param["objective"] = dataset.objective
|
||||
if dataset.objective == "multi:softmax":
|
||||
@ -99,10 +100,13 @@ def train_dataset(dataset, param_in, num_rounds=10, scale_features=False):
|
||||
if dataset.use_external_memory:
|
||||
np.savetxt('tmptmp_1234.csv', np.hstack((dataset.y.reshape(len(dataset.y), 1), X)),
|
||||
delimiter=',')
|
||||
dtrain = xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_',
|
||||
dtrain = DMatrixT('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_',
|
||||
weight=dataset.w)
|
||||
elif DMatrixT is xgb.DeviceQuantileDMatrix:
|
||||
import cupy as cp
|
||||
dtrain = DMatrixT(cp.array(X), dataset.y, weight=dataset.w, **dmatrix_params)
|
||||
else:
|
||||
dtrain = xgb.DMatrix(X, dataset.y, weight=dataset.w)
|
||||
dtrain = DMatrixT(X, dataset.y, weight=dataset.w, **dmatrix_params)
|
||||
|
||||
print("Training on dataset: " + dataset.name, file=sys.stderr)
|
||||
print("Using parameters: " + str(param), file=sys.stderr)
|
||||
@ -139,7 +143,8 @@ def parameter_combinations(variable_param):
|
||||
return result
|
||||
|
||||
|
||||
def run_suite(param, num_rounds=10, select_datasets=None, scale_features=False):
|
||||
def run_suite(param, num_rounds=10, select_datasets=None, scale_features=False,
|
||||
DMatrixT=xgb.DMatrix, dmatrix_params={}):
|
||||
"""
|
||||
Run the given parameters on a range of datasets. Objective and eval metric will be automatically set
|
||||
"""
|
||||
@ -162,7 +167,8 @@ def run_suite(param, num_rounds=10, select_datasets=None, scale_features=False):
|
||||
for d in datasets:
|
||||
if select_datasets is None or d.name in select_datasets:
|
||||
results.append(
|
||||
train_dataset(d, param, num_rounds=num_rounds, scale_features=scale_features))
|
||||
train_dataset(d, param, num_rounds=num_rounds, scale_features=scale_features,
|
||||
DMatrixT=DMatrixT, dmatrix_params=dmatrix_params))
|
||||
return results
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user