Support dmatrix construction from cupy array (#5206)

Rory Mitchell 2020-01-22 13:15:27 +13:00 committed by GitHub
parent 2a071cebc5
commit 9c56480c61
19 changed files with 522 additions and 158 deletions
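
For orientation, here is a minimal sketch of the usage this commit enables (assuming an XGBoost build with CUDA support and cupy installed; parameter choices are illustrative):

    import cupy as cp
    import xgboost as xgb

    # Data is created and stays on the GPU; DMatrix now accepts any object
    # exposing __cuda_array_interface__, so no device-to-host copy of X is needed.
    X = cp.random.randn(100, 4, dtype=cp.float32)
    y = cp.random.randn(100, dtype=cp.float32)
    dtrain = xgb.DMatrix(X, label=y)
    bst = xgb.train({'tree_method': 'gpu_hist', 'gpu_id': 0}, dtrain)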


@@ -115,6 +115,7 @@ except ImportError:
DT_INSTALLED = False
# cudf
try:
from cudf import DataFrame as CUDF_DataFrame
from cudf import Series as CUDF_Series


@@ -24,7 +24,6 @@ from .compat import (
os_fspath, os_PathLike)
from .libpath import find_lib_path
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
c_bst_ulong = ctypes.c_uint64
@@ -41,6 +40,7 @@ class EarlyStopException(Exception):
best_iteration : int
The best iteration stopped.
"""
def __init__(self, best_iteration):
super(EarlyStopException, self).__init__()
self.best_iteration = best_iteration
@@ -231,50 +231,16 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)
def _use_columnar_initializer(data):
"""Whether should we use columnar format initializer (pass data in as json
string). Currently cudf is the only valid option. For other dataframe
types, use their sepcific API instead.
"""
if CUDF_INSTALLED and (isinstance(data, (CUDF_DataFrame, CUDF_Series))):
return True
return False
def _extract_interface_from_cudf_series(data):
"""This returns the array interface from the cudf series. This function
should be upstreamed to cudf.
"""
interface = data.__cuda_array_interface__
if data.has_null_mask:
interface['mask'] = interface['mask'].__cuda_array_interface__
return interface
def _extract_interface_from_cudf(df):
"""This function should be upstreamed to cudf."""
if not _use_columnar_initializer(df):
raise ValueError('Only cudf is supported for initializing as json ' +
'columnar format. For other libraries please ' +
'refer to specific API.')
array_interfaces = []
if isinstance(df, CUDF_DataFrame):
for col in df.columns:
array_interfaces.append(
_extract_interface_from_cudf_series(df[col]))
else:
array_interfaces.append(_extract_interface_from_cudf_series(df))
interfaces = bytes(json.dumps(array_interfaces, indent=2), 'utf-8')
return interfaces
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
'bool': 'i'}
# Either the object itself has a CUDA array interface, or it contains columns that do
def _has_cuda_array_interface(data):
return hasattr(data, '__cuda_array_interface__') or (
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
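
As a quick illustration of the new dispatch predicate (a sketch assuming cupy and cudf are installed; the explicit isinstance check is needed because a cudf DataFrame does not expose the interface itself, only its columns do):

    import numpy as np
    import cupy as cp

    hasattr(np.ones(3), '__cuda_array_interface__')  # False: falls through to the CPU paths
    hasattr(cp.ones(3), '__cuda_array_interface__')  # True: routed to the interface-based paths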
def _maybe_pandas_data(data, feature_names, feature_types):
"""Extract internal data from pd.DataFrame for DMatrix data"""
@@ -433,7 +399,7 @@ class DMatrix(object):
"""Parameters
----------
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
dt.Frame/cudf.DataFrame
dt.Frame/cudf.DataFrame/cupy.array
Data source of DMatrix.
When data is string or os.PathLike type, it represents the path to a
libsvm format txt file, csv file (by specifying uri parameter
@@ -500,8 +466,10 @@ class DMatrix(object):
self._init_from_npy2d(data, missing, nthread)
elif isinstance(data, DataTable):
self._init_from_dt(data, nthread)
elif _use_columnar_initializer(data):
self._init_from_columnar(data, missing)
elif hasattr(data, "__cuda_array_interface__"):
self._init_from_array_interface(data, missing, nthread)
elif CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
self._init_from_array_interface_columns(data, missing, nthread)
else:
try:
csr = scipy.sparse.csr_matrix(data)
@@ -596,7 +564,8 @@ class DMatrix(object):
ptrs[icol] = ctypes.c_void_p(ptr)
else:
# datatable<=0.8.0
from datatable.internal import frame_column_data_r # pylint: disable=no-name-in-module,import-error
from datatable.internal import \
frame_column_data_r # pylint: disable=no-name-in-module,import-error
for icol in range(data.ncols):
ptrs[icol] = frame_column_data_r(data, icol)
@@ -614,16 +583,38 @@ class DMatrix(object):
nthread))
self.handle = handle
def _init_from_columnar(self, df, missing):
def _init_from_array_interface_columns(self, df, missing, nthread):
"""Initialize DMatrix from columnar memory format."""
interfaces = _extract_interface_from_cudf(df)
interfaces = []
for col in df:
interface = df[col].__cuda_array_interface__
if 'mask' in interface:
interface['mask'] = interface['mask'].__cuda_array_interface__
interfaces.append(interface)
handle = ctypes.c_void_p()
has_missing = missing is not None
missing = missing if has_missing else np.nan
missing = missing if missing is not None else np.nan
nthread = nthread if nthread is not None else 1
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
_check_call(
_LIB.XGDMatrixCreateFromArrayInterfaces(
interfaces, ctypes.c_int32(has_missing),
ctypes.c_float(missing), ctypes.byref(handle)))
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
interfaces_str,
ctypes.c_float(missing), ctypes.c_int(nthread), ctypes.byref(handle)))
self.handle = handle
def _init_from_array_interface(self, data, missing, nthread):
"""Initialize DMatrix from cupy ndarray."""
interface = data.__cuda_array_interface__
if 'mask' in interface:
interface['mask'] = interface['mask'].__cuda_array_interface__
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
handle = ctypes.c_void_p()
missing = missing if missing is not None else np.nan
nthread = nthread if nthread is not None else 1
_check_call(
_LIB.XGDMatrixCreateFromArrayInterface(
interface_str,
ctypes.c_float(missing), ctypes.c_int(nthread), ctypes.byref(handle)))
self.handle = handle
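
To see the JSON that `_init_from_array_interface` serializes and hands to `XGDMatrixCreateFromArrayInterface`, one can dump a cupy array's interface directly (a sketch; the pointer value and interface version vary by setup and cupy release):

    import json
    import cupy as cp

    X = cp.arange(6, dtype=cp.float32).reshape(2, 3)
    print(json.dumps(X.__cuda_array_interface__, indent=2))
    # Roughly:
    # {
    #   "shape": [2, 3],
    #   "typestr": "<f4",                  # little-endian float32
    #   "data": [140234216734720, false],  # (device pointer, read-only flag)
    #   "version": 1,                      # or higher, depending on the cupy release
    #   ...
    # }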
def __del__(self):
@@ -694,11 +685,18 @@ class DMatrix(object):
c_bst_ulong(len(data))))
def set_interface_info(self, field, data):
"""Set info type peoperty into DMatrix."""
interfaces = _extract_interface_from_cudf(data)
"""Set info type property into DMatrix."""
# If we are passed a dataframe, extract the series
if isinstance(data, CUDF_DataFrame):
if len(data.columns) != 1:
raise ValueError('Expecting meta-info to contain a single column')
data = data[data.columns[0]]
interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), 'utf-8')
_check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle,
c_str(field),
interfaces))
interface))
def set_float_info_npy2d(self, field, data):
"""Set float type property into the DMatrix
@@ -779,7 +777,7 @@ class DMatrix(object):
"""
if isinstance(label, np.ndarray):
self.set_label_npy2d(label)
elif _use_columnar_initializer(label):
elif _has_cuda_array_interface(label):
self.set_interface_info('label', label)
else:
self.set_float_info('label', label)
@@ -812,7 +810,7 @@ class DMatrix(object):
"""
if isinstance(weight, np.ndarray):
self.set_weight_npy2d(weight)
elif _use_columnar_initializer(weight):
elif _has_cuda_array_interface(weight):
self.set_interface_info('weight', weight)
else:
self.set_float_info('weight', weight)
@@ -849,7 +847,7 @@ class DMatrix(object):
margin: array like
Prediction margin of each datapoint
"""
if _use_columnar_initializer(margin):
if _has_cuda_array_interface(margin):
self.set_interface_info('base_margin', margin)
else:
self.set_float_info('base_margin', margin)
@@ -862,7 +860,7 @@ class DMatrix(object):
group : array like
Group size of each group
"""
if _use_columnar_initializer(group):
if _has_cuda_array_interface(group):
self.set_interface_info('group', group)
else:
self.set_uint_info('group', group)
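
Taken together, the setters above now route any GPU-resident array through `set_interface_info`; a usage sketch (assuming cupy; sizes are illustrative, and group sizes must sum to the number of rows):

    import numpy as np
    import cupy as cp
    import xgboost as xgb

    dtrain = xgb.DMatrix(np.random.random((10, 2)))
    dtrain.set_weight(cp.ones(10, dtype=cp.float32))             # GPU path via set_interface_info
    dtrain.set_group(cp.asarray([4, 3, 2, 1], dtype=cp.uint32))  # per-group sizes, not boundaries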
@@ -1073,7 +1071,7 @@ class Booster(object):
ctypes.byref(self.handle)))
if isinstance(params, dict) and \
'validate_parameters' not in params.keys():
'validate_parameters' not in params.keys():
params['validate_parameters'] = 1
self.set_param(params or {})
if (params is not None) and ('booster' in params):


@@ -206,12 +206,24 @@ int XGDMatrixCreateFromDataIter(
}
#ifndef XGBOOST_USE_CUDA
XGB_DLL int XGDMatrixCreateFromArrayInterfaces(
char const* c_json_strs, bst_int has_missing, bst_float missing, DMatrixHandle* out) {
XGB_DLL int XGDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
bst_float missing,
int nthread,
DMatrixHandle* out) {
API_BEGIN();
LOG(FATAL) << "Xgboost not compiled with cuda";
API_END();
}
XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
bst_float missing,
int nthread,
DMatrixHandle* out) {
API_BEGIN();
LOG(FATAL) << "Xgboost not compiled with cuda";
API_END();
}
#endif
XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,


@@ -7,14 +7,27 @@
#include "../data/device_adapter.cuh"
namespace xgboost {
XGB_DLL int XGDMatrixCreateFromArrayInterfaces(char const* c_json_strs,
bst_int has_missing,
bst_float missing,
DMatrixHandle* out) {
XGB_DLL int XGDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
bst_float missing,
int nthread,
DMatrixHandle* out) {
API_BEGIN();
std::string json_str{c_json_strs};
data::CudfAdapter adapter(json_str);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, 1));
*out =
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
API_END();
}
XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
bst_float missing, int nthread,
DMatrixHandle* out) {
API_BEGIN();
std::string json_str{c_json_strs};
data::CupyAdapter adapter(json_str);
*out =
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
API_END();
}
} // namespace xgboost


@ -1,14 +1,15 @@
/*!
* Copyright 2019 by Contributors
* \file columnar.h
* \file array_interface.h
* \brief Basic structure holding a reference to arrow columnar data format.
*/
#ifndef XGBOOST_DATA_COLUMNAR_H_
#define XGBOOST_DATA_COLUMNAR_H_
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
#include <cinttypes>
#include <map>
#include <string>
#include <utility>
#include "xgboost/data.h"
#include "xgboost/json.h"
@@ -18,7 +19,7 @@
namespace xgboost {
// Common errors in parsing columnar format.
struct ColumnarErrors {
struct ArrayInterfaceErrors {
static char const* Contigious() {
return "Memory should be contigious.";
}
@@ -119,15 +120,12 @@ class ArrayInterfaceHandler {
if (array.find("version") == array.cend()) {
LOG(FATAL) << "Missing `version' field for array interface";
}
auto version = get<Integer const>(array.at("version"));
CHECK_EQ(version, 1) << ColumnarErrors::Version();
if (array.find("typestr") == array.cend()) {
LOG(FATAL) << "Missing `typestr' field for array interface";
}
auto typestr = get<String const>(array.at("typestr"));
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
if (array.find("shape") == array.cend()) {
LOG(FATAL) << "Missing `shape' field for array interface";
@@ -149,7 +147,7 @@ class ArrayInterfaceHandler {
auto p_mask = GetPtrFromArrayData<RBitField8::value_type*>(j_mask);
auto j_shape = get<Array const>(j_mask.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(j_shape.size(), 1) << ArrayInterfaceErrors::Dimension(1);
auto typestr = get<String const>(j_mask.at("typestr"));
// For now this is just 1; we can support different sizes of integer in the mask.
int64_t const type_length = typestr.at(2) - 48;
@@ -178,8 +176,8 @@ class ArrayInterfaceHandler {
if (j_mask.find("strides") != j_mask.cend()) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ColumnarErrors::Contigious();
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contigious();
}
s_mask = {p_mask, span_size};
@@ -188,18 +186,28 @@ class ArrayInterfaceHandler {
return 0;
}
static size_t ExtractLength(std::map<std::string, Json> const& column) {
static std::pair<size_t, size_t> ExtractShape(
std::map<std::string, Json> const& column) {
auto j_shape = get<Array const>(column.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
auto typestr = get<String const>(column.at("typestr"));
if (column.find("strides") != column.cend()) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), typestr.at(2) - '0')
<< ColumnarErrors::Contigious();
if (!IsA<Null>(column.at("strides"))) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), j_shape.size())
<< ArrayInterfaceErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), typestr.at(2) - '0')
<< ArrayInterfaceErrors::Contigious();
}
}
return static_cast<size_t>(get<Integer const>(j_shape.at(0)));
if (j_shape.size() == 1) {
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))), 1};
} else {
CHECK_EQ(j_shape.size(), 2)
<< "Only 1D or 2-D arrays currently supported.";
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))),
static_cast<size_t>(get<Integer const>(j_shape.at(1)))};
}
}
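
The `IsA<Null>` guard added above matters because cupy serialises `strides` as `null` for C-contiguous arrays and only reports a tuple for strided views; a sketch (values illustrative):

    import cupy as cp

    X = cp.ones((4, 3), dtype=cp.float32)
    print(X.__cuda_array_interface__['strides'])    # None for a C-contiguous array
    print(X.T.__cuda_array_interface__['strides'])  # a tuple such as (4, 12) for a view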
template <typename T>
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
@@ -212,25 +220,27 @@ class ArrayInterfaceHandler {
<< "Input data type and typestr mismatch. typestr: " << typestr;
auto length = ExtractLength(column);
auto shape = ExtractShape(column);
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
return common::Span<T>{p_data, length};
return common::Span<T>{p_data, shape.first * shape.second};
}
};
// A view over __array_interface__
class Columnar {
class ArrayInterface {
using mask_type = unsigned char;
using index_type = int32_t;
public:
Columnar() = default;
explicit Columnar(std::map<std::string, Json> const& column) {
ArrayInterface() = default;
explicit ArrayInterface(std::map<std::string, Json> const& column) {
ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
CHECK(data) << "Column is null";
size = ArrayInterfaceHandler::ExtractLength(column);
auto shape = ArrayInterfaceHandler::ExtractShape(column);
num_rows = shape.first;
num_cols = shape.second;
common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
@@ -238,7 +248,7 @@ class Columnar {
valid = RBitField8(s_mask);
if (s_mask.data()) {
CHECK_EQ(n_bits, size)
CHECK_EQ(n_bits, num_rows)
<< "Shape of bit mask doesn't match data shape. "
<< "XGBoost doesn't support internal broadcasting.";
}
@@ -271,7 +281,7 @@ class Columnar {
} else if (type[1] == 'u' && type[2] == '8') {
return;
} else {
LOG(FATAL) << ColumnarErrors::UnSupportedType(type);
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(type);
return;
}
}
@@ -304,10 +314,11 @@ class Columnar {
}
RBitField8 valid;
int32_t size;
int32_t num_rows;
int32_t num_cols;
void* data;
char type[3];
};
} // namespace xgboost
#endif // XGBOOST_DATA_COLUMNAR_H_
#endif // XGBOOST_DATA_ARRAY_INTERFACE_H_
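
A rough Python mirror of the checks `ArrayInterfaceHandler::Validate` performs on the parsed JSON (a sketch of this file's rules, not an exhaustive validator; error strings are paraphrased):

    def validate_array_interface(interface):
        # Presence checks, as in Validate() above.
        for field in ('version', 'typestr', 'shape'):
            if field not in interface:
                raise ValueError("Missing `%s' field for array interface" % field)
        # typestr is <endianness><kind><bytes>, e.g. '<f4'; big endian is rejected.
        typestr = interface['typestr']
        if len(typestr) != 3:
            raise ValueError('`typestr` should be of format <endian><type><size>.')
        if typestr[0] == '>':
            raise ValueError('Big endian is not supported.')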


@@ -7,14 +7,14 @@
#include "xgboost/data.h"
#include "xgboost/logging.h"
#include "xgboost/json.h"
#include "columnar.h"
#include "array_interface.h"
#include "../common/device_helpers.cuh"
#include "device_adapter.cuh"
#include "simple_dmatrix.h"
namespace xgboost {
void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<float>* out) {
void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
auto SetDeviceToPtr = [](void* ptr) {
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
@@ -22,43 +22,42 @@ void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<fl
dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device;
};
Columnar foreign_column(column);
auto ptr_device = SetDeviceToPtr(foreign_column.data);
auto ptr_device = SetDeviceToPtr(column.data);
out->SetDevice(ptr_device);
out->Resize(foreign_column.size);
out->Resize(column.num_rows);
auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
dh::LaunchN(ptr_device, foreign_column.size, [=] __device__(size_t idx) {
p_dst[idx] = foreign_column.GetElement(idx);
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
p_dst[idx] = column.GetElement(idx);
});
}
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
auto const& j_arr = get<Array>(j_interface);
CHECK_EQ(j_arr.size(), 1) << "MetaInfo: " << c_key << ". " << ColumnarErrors::Dimension(1);;
auto const& j_arr_obj = get<Object const>(j_arr[0]);
std::string key {c_key};
ArrayInterfaceHandler::Validate(j_arr_obj);
if (j_arr_obj.find("mask") != j_arr_obj.cend()) {
LOG(FATAL) << "Meta info " << key << " should be dense, found validity mask";
}
auto const& typestr = get<String const>(j_arr_obj.at("typestr"));
CHECK_EQ(j_arr.size(), 1)
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
ArrayInterface array_interface(get<Object const>(j_arr[0]));
std::string key{c_key};
CHECK(!array_interface.valid.Data())
<< "Meta info " << key << " should be dense, found validity mask";
CHECK_EQ(array_interface.num_cols, 1)
<< "Meta info should be a single column.";
if (key == "label") {
CopyInfoImpl(j_arr_obj, &labels_);
CopyInfoImpl(array_interface, &labels_);
} else if (key == "weight") {
CopyInfoImpl(j_arr_obj, &weights_);
CopyInfoImpl(array_interface, &weights_);
} else if (key == "base_margin") {
CopyInfoImpl(j_arr_obj, &base_margin_);
CopyInfoImpl(array_interface, &base_margin_);
} else if (key == "group") {
// Ranking is not performed on device.
auto s_data = ArrayInterfaceHandler::ExtractData<uint32_t>(j_arr_obj);
thrust::device_ptr<uint32_t> p_src {s_data.data()};
thrust::device_ptr<uint32_t> p_src{
reinterpret_cast<uint32_t*>(array_interface.data)};
auto length = s_data.size();
auto length = array_interface.num_rows;
group_ptr_.resize(length + 1);
group_ptr_[0] = 0;
thrust::copy(p_src, p_src + length, group_ptr_.begin() + 1);
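
The `group` branch copies per-group sizes into `group_ptr_` one slot in; they are subsequently prefix-summed into cumulative boundaries (the scan step falls outside this hunk, but the test further down expects exactly this). Numerically, in Python terms:

    import numpy as np

    sizes = np.array([4, 3, 2, 1], dtype=np.uint32)      # group sizes, as passed to set_group
    group_ptr = np.concatenate([[0], np.cumsum(sizes)])  # -> [0, 4, 7, 9, 10]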
@@ -82,4 +81,7 @@ DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
template DMatrix* DMatrix::Create<data::CudfAdapter>(
data::CudfAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::CupyAdapter>(
data::CupyAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
} // namespace xgboost


@@ -7,9 +7,9 @@
#include <limits>
#include <memory>
#include <string>
#include "columnar.h"
#include "adapter.h"
#include "../common/device_helpers.cuh"
#include "adapter.h"
#include "array_interface.h"
namespace xgboost {
namespace data {
@@ -17,12 +17,13 @@ namespace data {
class CudfAdapterBatch : public detail::NoMetaInfo {
public:
CudfAdapterBatch() = default;
CudfAdapterBatch(common::Span<Columnar> columns,
CudfAdapterBatch(common::Span<ArrayInterface> columns,
common::Span<size_t> column_ptr, size_t num_elements)
: columns_(columns),column_ptr_(column_ptr), num_elements(num_elements) {}
size_t Size()const { return num_elements; }
__device__ COOTuple GetElement(size_t idx)const
{
: columns_(columns),
column_ptr_(column_ptr),
num_elements(num_elements) {}
size_t Size() const { return num_elements; }
__device__ COOTuple GetElement(size_t idx) const {
size_t column_idx =
dh::UpperBound(column_ptr_.data(), column_ptr_.size(), idx) - 1;
auto& column = columns_[column_idx];
@@ -34,22 +35,23 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
}
private:
common::Span<Columnar> columns_;
common::Span<ArrayInterface> columns_;
common::Span<size_t> column_ptr_;
size_t num_elements;
};
/*!
* Please be careful that, in official specification, the only three required fields are
* `shape', `version' and `typestr'. Any other is optional, including `data'. But here
* we have one additional requirements for input data:
* Please be careful that, in the official specification, the only three
* required fields are `shape', `version' and `typestr'. Any other field is
* optional, including `data'. But here we have one additional requirement for
* input data:
*
* - `data' field is required, passing in an empty dataset is not accepted, as most (if
* not all) of our algorithms don't have test for empty dataset. An error is better
* than a crash.
* - `data' field is required; passing in an empty dataset is not accepted, as
* most (if not all) of our algorithms don't have tests for empty datasets. An
* error is better than a crash.
*
* What if invalid value from dataframe is 0 but I specify missing=NaN in XGBoost? Since
* validity mask is ignored, all 0s are preserved in XGBoost.
* What if the invalid value from the dataframe is 0 but I specify missing=NaN
* in XGBoost? Since the validity mask is ignored, all 0s are preserved in XGBoost.
*
* FIXME(trivialfis): Put above into document after we have a consistent way for
* processing input data.
@@ -96,23 +98,23 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
CHECK_GT(n_columns, 0) << "Number of columns must not be 0.";
auto const& typestr = get<String const>(json_columns[0]["typestr"]);
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
std::vector<Columnar> columns;
CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
std::vector<ArrayInterface> columns;
std::vector<size_t> column_ptr({0});
auto first_column = Columnar(get<Object const>(json_columns[0]));
auto first_column = ArrayInterface(get<Object const>(json_columns[0]));
device_idx_ = dh::CudaGetPointerDevice(first_column.data);
CHECK_NE(device_idx_, -1);
dh::safe_cuda(cudaSetDevice(device_idx_));
num_rows_ = first_column.size;
num_rows_ = first_column.num_rows;
for (auto& json_col : json_columns) {
auto column = Columnar(get<Object const>(json_col));
auto column = ArrayInterface(get<Object const>(json_col));
columns.push_back(column);
column_ptr.emplace_back(column_ptr.back() + column.size);
num_rows_ = std::max(num_rows_, size_t(column.size));
column_ptr.emplace_back(column_ptr.back() + column.num_rows);
num_rows_ = std::max(num_rows_, size_t(column.num_rows));
CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
<< "All columns should use the same device.";
CHECK_EQ(num_rows_, column.size)
CHECK_EQ(num_rows_, column.num_rows)
<< "All columns should have same number of rows.";
}
columns_ = columns;
@@ -124,19 +126,65 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
size_t NumRows() const { return num_rows_; }
size_t NumColumns() const { return columns_.size(); }
size_t DeviceIdx()const {
return device_idx_;
}
size_t DeviceIdx() const { return device_idx_; }
// Cudf is column major
bool IsRowMajor() { return false; }
private:
CudfAdapterBatch batch;
dh::device_vector<Columnar> columns_;
dh::device_vector<ArrayInterface> columns_;
dh::device_vector<size_t> column_ptr_; // Exclusive scan of column sizes
size_t num_rows_{0};
int device_idx_;
};
class CupyAdapterBatch : public detail::NoMetaInfo {
public:
CupyAdapterBatch() = default;
CupyAdapterBatch(ArrayInterface array_interface)
: array_interface_(array_interface) {}
size_t Size() const {
return array_interface_.num_rows * array_interface_.num_cols;
}
__device__ COOTuple GetElement(size_t idx) const {
size_t column_idx = idx % array_interface_.num_cols;
size_t row_idx = idx / array_interface_.num_cols;
float value = array_interface_.valid.Data() == nullptr ||
array_interface_.valid.Check(row_idx)
? array_interface_.GetElement(idx)
: std::numeric_limits<float>::quiet_NaN();
return COOTuple(row_idx, column_idx, value);
}
private:
ArrayInterface array_interface_;
};
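
The index arithmetic in `CupyAdapterBatch::GetElement` is plain row-major decomposition; the same computation in Python terms:

    def coo_from_flat(idx, num_cols):
        # Mirrors GetElement: recover (row, column) from a flat row-major index.
        row_idx, column_idx = divmod(idx, num_cols)
        return row_idx, column_idx

    coo_from_flat(7, num_cols=3)  # -> (2, 1): element 7 of a 3-wide array is row 2, col 1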
class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
public:
explicit CupyAdapter(std::string cuda_interface_str) {
Json json_array_interface =
Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
array_interface = ArrayInterface(get<Object const>(json_array_interface));
device_idx_ = dh::CudaGetPointerDevice(array_interface.data);
CHECK_NE(device_idx_, -1);
batch = CupyAdapterBatch(array_interface);
}
const CupyAdapterBatch& Value() const override { return batch; }
size_t NumRows() const { return array_interface.num_rows; }
size_t NumColumns() const { return array_interface.num_cols; }
size_t DeviceIdx() const { return device_idx_; }
bool IsRowMajor() { return true; }
private:
ArrayInterface array_interface;
CupyAdapterBatch batch;
int device_idx_;
};
}; // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_DEVICE_ADAPTER_H_


@@ -7,7 +7,6 @@
#include <xgboost/json.h>
#include "simple_csr_source.h"
#include "columnar.h"
namespace xgboost {
namespace data {


@@ -78,6 +78,35 @@ void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
}
}
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
explicit IsValidFunctor(float missing) : missing(missing) {}
float missing;
__device__ bool operator()(const Entry& x) const {
return IsValid(x.fvalue, missing);
}
};
// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data
template <typename AdapterT>
void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data,
int device_idx, float missing,
common::Span<size_t> row_ptr) {
auto& batch = adapter->Value();
auto transform_f = [=] __device__(size_t idx) {
const auto& e = batch.GetElement(idx);
return Entry(e.column_idx, e.value);
}; // NOLINT
auto counting = thrust::make_counting_iterator(0llu);
thrust::transform_iterator<decltype(transform_f), decltype(counting), Entry>
transform_iter(counting, transform_f);
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::copy_if(
thrust::cuda::par(alloc), transform_iter, transform_iter + batch.Size(),
thrust::device_pointer_cast(data.data()), IsValidFunctor(missing));
}
// Does not currently support metainfo as no on-device data source contains this
// Current implementation assumes a single batch. More batches can
// be supported in future. Does not currently support inferring row/column size
@@ -102,11 +131,14 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
mat.info.num_nonzero_ = mat.page_.offset.HostVector().back();
mat.page_.data.Resize(mat.info.num_nonzero_);
if (adapter->IsRowMajor()) {
LOG(FATAL) << "Not implemented.";
CopyDataRowMajor(adapter, mat.page_.data.DeviceSpan(),
adapter->DeviceIdx(), missing, s_offset);
} else {
CopyDataColumnMajor(adapter, mat.page_.data.DeviceSpan(),
adapter->DeviceIdx(), missing, s_offset);
}
// Sync
mat.page_.data.HostVector();
mat.info.num_col_ = adapter->NumColumns();
mat.info.num_row_ = adapter->NumRows();
@@ -116,5 +148,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(CupyAdapter* adapter, float missing,
int nthread);
} // namespace data
} // namespace xgboost


@@ -18,7 +18,7 @@ ENV PATH=/opt/python/bin:$PATH
# Create new Conda environment with cuDF and dask
RUN \
conda create -n cudf_test -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda \
cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask dask-cuda
cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask dask-cuda cupy
# Install other Python packages
RUN \


@@ -39,7 +39,7 @@ case "$suite" in
cudf)
source activate cudf_test
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py tests/python-gpu/test_from_cupy.py
;;
cpu)


@@ -8,7 +8,6 @@
#include "../../../src/common/bitfield.h"
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/data/simple_csr_source.h"
#include "../../../src/data/columnar.h"
namespace xgboost {
@@ -62,4 +61,24 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows,
column["typestr"] = String(typestr);
return column;
}
template <typename T>
Json Generate2dArrayInterface(int rows, int cols, std::string typestr,
thrust::device_vector<T>* p_data) {
auto& data = *p_data;
thrust::sequence(data.begin(), data.end());
Json array_interface{Object()};
std::vector<Json> shape = {Json(static_cast<Integer::Int>(rows)),
Json(static_cast<Integer::Int>(cols))};
array_interface["shape"] = Array(shape);
std::vector<Json> j_data{
Json(Integer(reinterpret_cast<Integer::Int>(data.data().get()))),
Json(Boolean(false))};
array_interface["data"] = j_data;
array_interface["version"] = Integer(static_cast<Integer::Int>(1));
array_interface["typestr"] = String(typestr);
return array_interface;
}
} // namespace xgboost
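
The `Generate2dArrayInterface` helper above builds the interface JSON by hand for the C++ tests; cupy produces the equivalent dictionary for free, which is how the Python tests exercise the same code path (a sketch; the pointer value varies):

    import cupy as cp

    data = cp.arange(50 * 10, dtype=cp.float32).reshape(50, 10)  # same values as thrust::sequence
    data.__cuda_array_interface__
    # e.g. {'shape': (50, 10), 'typestr': '<f4', 'data': (<device ptr>, False), 'version': 1, ...}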


@@ -7,7 +7,7 @@
#include "../helpers.h"
#include <thrust/device_vector.h>
#include "../../../src/data/device_adapter.cuh"
#include "test_columnar.h"
#include "test_array_interface.h"
using namespace xgboost; // NOLINT
void TestCudfAdapter()


@@ -9,8 +9,7 @@
namespace xgboost {
template <typename T>
std::string PrepareData(std::string typestr, thrust::device_vector<T>* out) {
constexpr size_t kRows = 16;
std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, const size_t kRows = 16) {
out->resize(kRows);
auto& d_data = *out;
@@ -66,7 +65,15 @@ TEST(MetaInfo, FromInterface) {
ASSERT_EQ(h_base_margin[i], d_data[i]);
}
EXPECT_ANY_THROW({info.SetInfo("group", str.c_str());});
thrust::device_vector<int> d_group_data;
std::string group_str = PrepareData<int>("<i4", &d_group_data, 4);
d_group_data[0] = 4;
d_group_data[1] = 3;
d_group_data[2] = 2;
d_group_data[3] = 1;
info.SetInfo("group", group_str.c_str());
std::vector<bst_group_t> expected_group_ptr = {0, 4, 7, 9, 10};
EXPECT_EQ(info.group_ptr_, expected_group_ptr);
}
TEST(MetaInfo, Group) {
@@ -83,4 +90,4 @@ TEST(MetaInfo, Group) {
ASSERT_EQ(h_group[i], d_data[i-1] + h_group[i-1]) << "i: " << i;
}
}
} // namespace xgboost
} // namespace xgboost


@@ -6,7 +6,8 @@
#include <thrust/sequence.h>
#include "../../../src/data/device_adapter.cuh"
#include "../helpers.h"
#include "test_columnar.h"
#include "test_array_interface.h"
#include "../../../src/data/array_interface.h"
using namespace xgboost; // NOLINT
@@ -316,3 +317,55 @@ TEST(SimpleDMatrix, FromColumnarSparseBasic) {
}
}
}
TEST(SimpleDMatrix, FromCupy) {
int rows = 50;
int cols = 10;
thrust::device_vector<float> data(rows * cols);
auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
std::stringstream ss;
Json::Dump(json_array_interface, &ss);
std::string str = ss.str();
data::CupyAdapter adapter(str);
data::SimpleDMatrix dmat(&adapter, -1, 1);
EXPECT_EQ(dmat.Info().num_col_, cols);
EXPECT_EQ(dmat.Info().num_row_, rows);
EXPECT_EQ(dmat.Info().num_nonzero_, rows * cols);
for (auto& batch : dmat.GetBatches<SparsePage>()) {
for (auto i = 0ull; i < batch.Size(); i++) {
auto inst = batch[i];
for (auto j = 0ull; j < inst.size(); j++) {
EXPECT_EQ(inst[j].fvalue, i * cols + j);
EXPECT_EQ(inst[j].index, j);
}
}
}
}
TEST(SimpleDMatrix, FromCupySparse) {
int rows = 2;
int cols = 2;
thrust::device_vector<float> data(rows * cols);
auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
data[1] = std::numeric_limits<float>::quiet_NaN();
data[2] = std::numeric_limits<float>::quiet_NaN();
std::stringstream ss;
Json::Dump(json_array_interface, &ss);
std::string str = ss.str();
data::CupyAdapter adapter(str);
data::SimpleDMatrix dmat(&adapter, -1, 1);
EXPECT_EQ(dmat.Info().num_col_, cols);
EXPECT_EQ(dmat.Info().num_row_, rows);
EXPECT_EQ(dmat.Info().num_nonzero_, rows * cols - 2);
auto& batch = *dmat.GetBatches<SparsePage>().begin();
auto inst0 = batch[0];
auto inst1 = batch[1];
EXPECT_EQ(batch[0].size(), 1);
EXPECT_EQ(batch[1].size(), 1);
EXPECT_EQ(batch[0][0].fvalue, 0.0f);
EXPECT_EQ(batch[0][0].index, 0);
EXPECT_EQ(batch[1][0].fvalue, 3.0f);
EXPECT_EQ(batch[1][0].index, 1);
}


@@ -2,6 +2,7 @@ import numpy as np
import xgboost as xgb
import sys
import pytest
sys.path.append("tests/python")
import testing as tm
@@ -86,3 +87,64 @@ Arrow specification.'''
'x': cudf.Series([True, False, True, True, True])})
with pytest.raises(Exception):
dtrain = xgb.DMatrix(X_boolean, label=y_boolean)
@pytest.mark.skipif(**tm.no_cudf())
def test_cudf_training(self):
from cudf import DataFrame as df
import pandas as pd
X = pd.DataFrame(np.random.randn(50, 10))
y = pd.DataFrame(np.random.randn(50))
weights = np.random.random(50)
cudf_weights = df.from_pandas(pd.DataFrame(weights))
base_margin = np.random.random(50)
cudf_base_margin = df.from_pandas(pd.DataFrame(base_margin))
evals_result_cudf = {}
dtrain_cudf = xgb.DMatrix(df.from_pandas(X), df.from_pandas(y), weight=cudf_weights,
base_margin=cudf_base_margin)
xgb.train({'gpu_id': 0}, dtrain_cudf, evals=[(dtrain_cudf, "train")],
evals_result=evals_result_cudf)
evals_result_np = {}
dtrain_np = xgb.DMatrix(X, y, weight=weights, base_margin=base_margin)
xgb.train({}, dtrain_np, evals=[(dtrain_np, "train")],
evals_result=evals_result_np)
assert np.array_equal(evals_result_cudf["train"]["rmse"], evals_result_np["train"]["rmse"])
@pytest.mark.skipif(**tm.no_cudf())
def test_cudf_metainfo(self):
from cudf import DataFrame as df
import pandas as pd
n = 100
X = np.random.random((n, 2))
dmat_cudf = xgb.DMatrix(X)
dmat = xgb.DMatrix(X)
floats = np.random.random(n)
uints = np.array([4, 2, 8]).astype("uint32")
cudf_floats = df.from_pandas(pd.DataFrame(floats))
cudf_uints = df.from_pandas(pd.DataFrame(uints))
dmat.set_float_info('weight', floats)
dmat.set_float_info('label', floats)
dmat.set_float_info('base_margin', floats)
dmat.set_uint_info('group', uints)
dmat_cudf.set_interface_info('weight', cudf_floats)
dmat_cudf.set_interface_info('label', cudf_floats)
dmat_cudf.set_interface_info('base_margin', cudf_floats)
dmat_cudf.set_interface_info('group', cudf_uints)
# Test setting info with cudf DataFrame
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
assert np.array_equal(dmat.get_float_info('base_margin'),
dmat_cudf.get_float_info('base_margin'))
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
# Test setting info with cudf Series
dmat_cudf.set_interface_info('weight', cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_interface_info('label', cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_interface_info('base_margin', cudf_floats[cudf_floats.columns[0]])
dmat_cudf.set_interface_info('group', cudf_uints[cudf_uints.columns[0]])
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
assert np.array_equal(dmat.get_float_info('base_margin'),
dmat_cudf.get_float_info('base_margin'))
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))


@@ -0,0 +1,97 @@
import numpy as np
import xgboost as xgb
import sys
import pytest
sys.path.append("tests/python")
import testing as tm
def dmatrix_from_cupy(input_type, missing=np.NAN):
'''Test constructing DMatrix from cupy'''
import cupy as cp
kRows = 80
kCols = 3
np_X = np.random.randn(kRows, kCols).astype(dtype=input_type)
X = cp.array(np_X)
X[5, 0] = missing
X[3, 1] = missing
y = cp.random.randn(kRows).astype(dtype=input_type)
dtrain = xgb.DMatrix(X, missing=missing, label=y)
assert dtrain.num_col() == kCols
assert dtrain.num_row() == kRows
return dtrain
class TestFromArrayInterface:
'''Tests for constructing DMatrix from data structures conforming to the
`__cuda_array_interface__` specification.'''
@pytest.mark.skipif(**tm.no_cupy())
def test_from_cupy(self):
'''Test constructing DMatrix from cupy'''
import cupy as cp
dmatrix_from_cupy(np.float32, np.NAN)
dmatrix_from_cupy(np.float64, np.NAN)
dmatrix_from_cupy(np.uint8, 2)
dmatrix_from_cupy(np.uint32, 3)
dmatrix_from_cupy(np.uint64, 4)
dmatrix_from_cupy(np.int8, 2)
dmatrix_from_cupy(np.int32, -2)
dmatrix_from_cupy(np.int64, -3)
with pytest.raises(Exception):
X = cp.random.randn(2, 2, dtype="float32")
dtrain = xgb.DMatrix(X, label=X)
@pytest.mark.skipif(**tm.no_cupy())
def test_cupy_training(self):
import cupy as cp
X = cp.random.randn(50, 10, dtype="float32")
y = cp.random.randn(50, dtype="float32")
weights = np.random.random(50)
cupy_weights = cp.array(weights)
base_margin = np.random.random(50)
cupy_base_margin = cp.array(base_margin)
evals_result_cupy = {}
dtrain_cp = xgb.DMatrix(X, y, weight=cupy_weights, base_margin=cupy_base_margin)
xgb.train({'gpu_id': 0}, dtrain_cp, evals=[(dtrain_cp, "train")],
evals_result=evals_result_cupy)
evals_result_np = {}
dtrain_np = xgb.DMatrix(cp.asnumpy(X), cp.asnumpy(y), weight=weights,
base_margin=base_margin)
xgb.train({'gpu_id': 0}, dtrain_np, evals=[(dtrain_np, "train")],
evals_result=evals_result_np)
assert np.array_equal(evals_result_cupy["train"]["rmse"], evals_result_np["train"]["rmse"])
@pytest.mark.skipif(**tm.no_cupy())
def test_cupy_metainfo(self):
import cupy as cp
n = 100
X = np.random.random((n, 2))
dmat_cupy = xgb.DMatrix(X)
dmat = xgb.DMatrix(X)
floats = np.random.random(n)
uints = np.array([4, 2, 8]).astype("uint32")
cupy_floats = cp.array(floats)
cupy_uints = cp.array(uints)
dmat.set_float_info('weight', floats)
dmat.set_float_info('label', floats)
dmat.set_float_info('base_margin', floats)
dmat.set_uint_info('group', uints)
dmat_cupy.set_interface_info('weight', cupy_floats)
dmat_cupy.set_interface_info('label', cupy_floats)
dmat_cupy.set_interface_info('base_margin', cupy_floats)
dmat_cupy.set_interface_info('group', cupy_uints)
# Test setting info with cupy
assert np.array_equal(dmat.get_float_info('weight'), dmat_cupy.get_float_info('weight'))
assert np.array_equal(dmat.get_float_info('label'), dmat_cupy.get_float_info('label'))
assert np.array_equal(dmat.get_float_info('base_margin'),
dmat_cupy.get_float_info('base_margin'))
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))

View File

@@ -48,6 +48,15 @@ def no_cudf():
'reason': 'CUDF is not installed'}
def no_cupy():
reason = 'cupy is not installed.'
try:
import cupy as _ # noqa
return {'condition': False, 'reason': reason}
except ImportError:
return {'condition': True, 'reason': reason}
def no_dask_cudf():
reason = 'dask_cudf is not installed.'
try:


@@ -16,10 +16,9 @@ if [ ${TASK} == "python_test" ]; then
echo "-------------------------------"
conda activate python3
python --version
conda install numpy scipy pandas matplotlib scikit-learn
conda install numpy scipy pandas matplotlib scikit-learn dask
python -m pip install graphviz pytest pytest-cov codecov
python -m pip install dask distributed dask[dataframe]
python -m pip install datatable
python -m pytest -v --fulltrace -s tests/python --cov=python-package/xgboost || exit -1
codecov