Support dmatrix construction from cupy array (#5206)
This commit is contained in:
parent
2a071cebc5
commit
9c56480c61
@ -115,6 +115,7 @@ except ImportError:
|
||||
DT_INSTALLED = False
|
||||
|
||||
|
||||
# cudf
|
||||
try:
|
||||
from cudf import DataFrame as CUDF_DataFrame
|
||||
from cudf import Series as CUDF_Series
|
||||
|
||||
@ -24,7 +24,6 @@ from .compat import (
|
||||
os_fspath, os_PathLike)
|
||||
from .libpath import find_lib_path
|
||||
|
||||
|
||||
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
||||
c_bst_ulong = ctypes.c_uint64
|
||||
|
||||
@ -41,6 +40,7 @@ class EarlyStopException(Exception):
|
||||
best_iteration : int
|
||||
The best iteration stopped.
|
||||
"""
|
||||
|
||||
def __init__(self, best_iteration):
|
||||
super(EarlyStopException, self).__init__()
|
||||
self.best_iteration = best_iteration
|
||||
@ -231,50 +231,16 @@ def c_array(ctype, values):
|
||||
return (ctype * len(values))(*values)
|
||||
|
||||
|
||||
def _use_columnar_initializer(data):
|
||||
"""Whether should we use columnar format initializer (pass data in as json
|
||||
string). Currently cudf is the only valid option. For other dataframe
|
||||
types, use their sepcific API instead.
|
||||
"""
|
||||
if CUDF_INSTALLED and (isinstance(data, (CUDF_DataFrame, CUDF_Series))):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _extract_interface_from_cudf_series(data):
|
||||
"""This returns the array interface from the cudf series. This function
|
||||
should be upstreamed to cudf.
|
||||
"""
|
||||
interface = data.__cuda_array_interface__
|
||||
if data.has_null_mask:
|
||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
||||
return interface
|
||||
|
||||
|
||||
def _extract_interface_from_cudf(df):
|
||||
"""This function should be upstreamed to cudf."""
|
||||
if not _use_columnar_initializer(df):
|
||||
raise ValueError('Only cudf is supported for initializing as json ' +
|
||||
'columnar format. For other libraries please ' +
|
||||
'refer to specific API.')
|
||||
|
||||
array_interfaces = []
|
||||
if isinstance(df, CUDF_DataFrame):
|
||||
for col in df.columns:
|
||||
array_interfaces.append(
|
||||
_extract_interface_from_cudf_series(df[col]))
|
||||
else:
|
||||
array_interfaces.append(_extract_interface_from_cudf_series(df))
|
||||
|
||||
interfaces = bytes(json.dumps(array_interfaces, indent=2), 'utf-8')
|
||||
return interfaces
|
||||
|
||||
|
||||
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
|
||||
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
|
||||
'float16': 'float', 'float32': 'float', 'float64': 'float',
|
||||
'bool': 'i'}
|
||||
|
||||
# Either object has cuda array interface or contains columns with interfaces
|
||||
def _has_cuda_array_interface(data):
|
||||
return hasattr(data, '__cuda_array_interface__') or (
|
||||
CUDF_INSTALLED and isinstance(data, CUDF_DataFrame))
|
||||
|
||||
|
||||
def _maybe_pandas_data(data, feature_names, feature_types):
|
||||
"""Extract internal data from pd.DataFrame for DMatrix data"""
|
||||
@ -433,7 +399,7 @@ class DMatrix(object):
|
||||
"""Parameters
|
||||
----------
|
||||
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
|
||||
dt.Frame/cudf.DataFrame
|
||||
dt.Frame/cudf.DataFrame/cupy.array
|
||||
Data source of DMatrix.
|
||||
When data is string or os.PathLike type, it represents the path
|
||||
libsvm format txt file, csv file (by specifying uri parameter
|
||||
@ -500,8 +466,10 @@ class DMatrix(object):
|
||||
self._init_from_npy2d(data, missing, nthread)
|
||||
elif isinstance(data, DataTable):
|
||||
self._init_from_dt(data, nthread)
|
||||
elif _use_columnar_initializer(data):
|
||||
self._init_from_columnar(data, missing)
|
||||
elif hasattr(data, "__cuda_array_interface__"):
|
||||
self._init_from_array_interface(data, missing, nthread)
|
||||
elif CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
|
||||
self._init_from_array_interface_columns(data, missing, nthread)
|
||||
else:
|
||||
try:
|
||||
csr = scipy.sparse.csr_matrix(data)
|
||||
@ -596,7 +564,8 @@ class DMatrix(object):
|
||||
ptrs[icol] = ctypes.c_void_p(ptr)
|
||||
else:
|
||||
# datatable<=0.8.0
|
||||
from datatable.internal import frame_column_data_r # pylint: disable=no-name-in-module,import-error
|
||||
from datatable.internal import \
|
||||
frame_column_data_r # pylint: disable=no-name-in-module,import-error
|
||||
for icol in range(data.ncols):
|
||||
ptrs[icol] = frame_column_data_r(data, icol)
|
||||
|
||||
@ -614,16 +583,38 @@ class DMatrix(object):
|
||||
nthread))
|
||||
self.handle = handle
|
||||
|
||||
def _init_from_columnar(self, df, missing):
|
||||
def _init_from_array_interface_columns(self, df, missing, nthread):
|
||||
"""Initialize DMatrix from columnar memory format."""
|
||||
interfaces = _extract_interface_from_cudf(df)
|
||||
interfaces = []
|
||||
for col in df:
|
||||
interface = df[col].__cuda_array_interface__
|
||||
if 'mask' in interface:
|
||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
||||
interfaces.append(interface)
|
||||
handle = ctypes.c_void_p()
|
||||
has_missing = missing is not None
|
||||
missing = missing if has_missing else np.nan
|
||||
missing = missing if missing is not None else np.nan
|
||||
nthread = nthread if nthread is not None else 1
|
||||
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromArrayInterfaces(
|
||||
interfaces, ctypes.c_int32(has_missing),
|
||||
ctypes.c_float(missing), ctypes.byref(handle)))
|
||||
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
interfaces_str,
|
||||
ctypes.c_float(missing), ctypes.c_int(nthread), ctypes.byref(handle)))
|
||||
self.handle = handle
|
||||
|
||||
def _init_from_array_interface(self, data, missing, nthread):
|
||||
"""Initialize DMatrix from cupy ndarray."""
|
||||
interface = data.__cuda_array_interface__
|
||||
if 'mask' in interface:
|
||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
||||
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
|
||||
|
||||
handle = ctypes.c_void_p()
|
||||
missing = missing if missing is not None else np.nan
|
||||
nthread = nthread if nthread is not None else 1
|
||||
_check_call(
|
||||
_LIB.XGDMatrixCreateFromArrayInterface(
|
||||
interface_str,
|
||||
ctypes.c_float(missing), ctypes.c_int(nthread), ctypes.byref(handle)))
|
||||
self.handle = handle
|
||||
|
||||
def __del__(self):
|
||||
@ -694,11 +685,18 @@ class DMatrix(object):
|
||||
c_bst_ulong(len(data))))
|
||||
|
||||
def set_interface_info(self, field, data):
|
||||
"""Set info type peoperty into DMatrix."""
|
||||
interfaces = _extract_interface_from_cudf(data)
|
||||
"""Set info type property into DMatrix."""
|
||||
|
||||
# If we are passed a dataframe, extract the series
|
||||
if isinstance(data, CUDF_DataFrame):
|
||||
if len(data.columns) != 1:
|
||||
raise ValueError('Expecting meta-info to contain a single column')
|
||||
data = data[data.columns[0]]
|
||||
|
||||
interface = bytes(json.dumps([data.__cuda_array_interface__], indent=2), 'utf-8')
|
||||
_check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle,
|
||||
c_str(field),
|
||||
interfaces))
|
||||
interface))
|
||||
|
||||
def set_float_info_npy2d(self, field, data):
|
||||
"""Set float type property into the DMatrix
|
||||
@ -779,7 +777,7 @@ class DMatrix(object):
|
||||
"""
|
||||
if isinstance(label, np.ndarray):
|
||||
self.set_label_npy2d(label)
|
||||
elif _use_columnar_initializer(label):
|
||||
elif _has_cuda_array_interface(label):
|
||||
self.set_interface_info('label', label)
|
||||
else:
|
||||
self.set_float_info('label', label)
|
||||
@ -812,7 +810,7 @@ class DMatrix(object):
|
||||
"""
|
||||
if isinstance(weight, np.ndarray):
|
||||
self.set_weight_npy2d(weight)
|
||||
elif _use_columnar_initializer(weight):
|
||||
elif _has_cuda_array_interface(weight):
|
||||
self.set_interface_info('weight', weight)
|
||||
else:
|
||||
self.set_float_info('weight', weight)
|
||||
@ -849,7 +847,7 @@ class DMatrix(object):
|
||||
margin: array like
|
||||
Prediction margin of each datapoint
|
||||
"""
|
||||
if _use_columnar_initializer(margin):
|
||||
if _has_cuda_array_interface(margin):
|
||||
self.set_interface_info('base_margin', margin)
|
||||
else:
|
||||
self.set_float_info('base_margin', margin)
|
||||
@ -862,7 +860,7 @@ class DMatrix(object):
|
||||
group : array like
|
||||
Group size of each group
|
||||
"""
|
||||
if _use_columnar_initializer(group):
|
||||
if _has_cuda_array_interface(group):
|
||||
self.set_interface_info('group', group)
|
||||
else:
|
||||
self.set_uint_info('group', group)
|
||||
@ -1073,7 +1071,7 @@ class Booster(object):
|
||||
ctypes.byref(self.handle)))
|
||||
|
||||
if isinstance(params, dict) and \
|
||||
'validate_parameters' not in params.keys():
|
||||
'validate_parameters' not in params.keys():
|
||||
params['validate_parameters'] = 1
|
||||
self.set_param(params or {})
|
||||
if (params is not None) and ('booster' in params):
|
||||
|
||||
@ -206,12 +206,24 @@ int XGDMatrixCreateFromDataIter(
|
||||
}
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
XGB_DLL int XGDMatrixCreateFromArrayInterfaces(
|
||||
char const* c_json_strs, bst_int has_missing, bst_float missing, DMatrixHandle* out) {
|
||||
XGB_DLL int XGDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
|
||||
bst_float missing,
|
||||
int nthread,
|
||||
DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
LOG(FATAL) << "Xgboost not compiled with cuda";
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
|
||||
bst_float missing,
|
||||
int nthread,
|
||||
DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
LOG(FATAL) << "Xgboost not compiled with cuda";
|
||||
API_END();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
|
||||
|
||||
@ -7,14 +7,27 @@
|
||||
#include "../data/device_adapter.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
XGB_DLL int XGDMatrixCreateFromArrayInterfaces(char const* c_json_strs,
|
||||
bst_int has_missing,
|
||||
bst_float missing,
|
||||
DMatrixHandle* out) {
|
||||
XGB_DLL int XGDMatrixCreateFromArrayInterfaceColumns(char const* c_json_strs,
|
||||
bst_float missing,
|
||||
int nthread,
|
||||
DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
std::string json_str{c_json_strs};
|
||||
data::CudfAdapter adapter(json_str);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, 1));
|
||||
*out =
|
||||
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
|
||||
bst_float missing, int nthread,
|
||||
DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
std::string json_str{c_json_strs};
|
||||
data::CupyAdapter adapter(json_str);
|
||||
*out =
|
||||
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||
API_END();
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,14 +1,15 @@
|
||||
/*!
|
||||
* Copyright 2019 by Contributors
|
||||
* \file columnar.h
|
||||
* \file array_interface.h
|
||||
* \brief Basic structure holding a reference to arrow columnar data format.
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_COLUMNAR_H_
|
||||
#define XGBOOST_DATA_COLUMNAR_H_
|
||||
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
|
||||
#include <cinttypes>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/json.h"
|
||||
@ -18,7 +19,7 @@
|
||||
|
||||
namespace xgboost {
|
||||
// Common errors in parsing columnar format.
|
||||
struct ColumnarErrors {
|
||||
struct ArrayInterfaceErrors {
|
||||
static char const* Contigious() {
|
||||
return "Memory should be contigious.";
|
||||
}
|
||||
@ -119,15 +120,12 @@ class ArrayInterfaceHandler {
|
||||
if (array.find("version") == array.cend()) {
|
||||
LOG(FATAL) << "Missing `version' field for array interface";
|
||||
}
|
||||
auto version = get<Integer const>(array.at("version"));
|
||||
CHECK_EQ(version, 1) << ColumnarErrors::Version();
|
||||
|
||||
if (array.find("typestr") == array.cend()) {
|
||||
LOG(FATAL) << "Missing `typestr' field for array interface";
|
||||
}
|
||||
auto typestr = get<String const>(array.at("typestr"));
|
||||
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
|
||||
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
|
||||
CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
|
||||
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
|
||||
|
||||
if (array.find("shape") == array.cend()) {
|
||||
LOG(FATAL) << "Missing `shape' field for array interface";
|
||||
@ -149,7 +147,7 @@ class ArrayInterfaceHandler {
|
||||
auto p_mask = GetPtrFromArrayData<RBitField8::value_type*>(j_mask);
|
||||
|
||||
auto j_shape = get<Array const>(j_mask.at("shape"));
|
||||
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
|
||||
CHECK_EQ(j_shape.size(), 1) << ArrayInterfaceErrors::Dimension(1);
|
||||
auto typestr = get<String const>(j_mask.at("typestr"));
|
||||
// For now this is just 1, we can support different size of interger in mask.
|
||||
int64_t const type_length = typestr.at(2) - 48;
|
||||
@ -178,8 +176,8 @@ class ArrayInterfaceHandler {
|
||||
|
||||
if (j_mask.find("strides") != j_mask.cend()) {
|
||||
auto strides = get<Array const>(column.at("strides"));
|
||||
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
|
||||
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ColumnarErrors::Contigious();
|
||||
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
|
||||
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contigious();
|
||||
}
|
||||
|
||||
s_mask = {p_mask, span_size};
|
||||
@ -188,18 +186,28 @@ class ArrayInterfaceHandler {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static size_t ExtractLength(std::map<std::string, Json> const& column) {
|
||||
static std::pair<size_t, size_t> ExtractShape(
|
||||
std::map<std::string, Json> const& column) {
|
||||
auto j_shape = get<Array const>(column.at("shape"));
|
||||
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
if (column.find("strides") != column.cend()) {
|
||||
auto strides = get<Array const>(column.at("strides"));
|
||||
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
|
||||
CHECK_EQ(get<Integer>(strides.at(0)), typestr.at(2) - '0')
|
||||
<< ColumnarErrors::Contigious();
|
||||
if (!IsA<Null>(column.at("strides"))) {
|
||||
auto strides = get<Array const>(column.at("strides"));
|
||||
CHECK_EQ(strides.size(), j_shape.size())
|
||||
<< ArrayInterfaceErrors::Dimension(1);
|
||||
CHECK_EQ(get<Integer>(strides.at(0)), typestr.at(2) - '0')
|
||||
<< ArrayInterfaceErrors::Contigious();
|
||||
}
|
||||
}
|
||||
|
||||
return static_cast<size_t>(get<Integer const>(j_shape.at(0)));
|
||||
if (j_shape.size() == 1) {
|
||||
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))), 1};
|
||||
} else {
|
||||
CHECK_EQ(j_shape.size(), 2)
|
||||
<< "Only 1D or 2-D arrays currently supported.";
|
||||
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))),
|
||||
static_cast<size_t>(get<Integer const>(j_shape.at(1)))};
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
|
||||
@ -212,25 +220,27 @@ class ArrayInterfaceHandler {
|
||||
<< "Input data type and typestr mismatch. typestr: " << typestr;
|
||||
|
||||
|
||||
auto length = ExtractLength(column);
|
||||
auto shape = ExtractShape(column);
|
||||
|
||||
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
||||
return common::Span<T>{p_data, length};
|
||||
return common::Span<T>{p_data, shape.first * shape.second};
|
||||
}
|
||||
};
|
||||
|
||||
// A view over __array_interface__
|
||||
class Columnar {
|
||||
class ArrayInterface {
|
||||
using mask_type = unsigned char;
|
||||
using index_type = int32_t;
|
||||
|
||||
public:
|
||||
Columnar() = default;
|
||||
explicit Columnar(std::map<std::string, Json> const& column) {
|
||||
ArrayInterface() = default;
|
||||
explicit ArrayInterface(std::map<std::string, Json> const& column) {
|
||||
ArrayInterfaceHandler::Validate(column);
|
||||
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
||||
CHECK(data) << "Column is null";
|
||||
size = ArrayInterfaceHandler::ExtractLength(column);
|
||||
auto shape = ArrayInterfaceHandler::ExtractShape(column);
|
||||
num_rows = shape.first;
|
||||
num_cols = shape.second;
|
||||
|
||||
common::Span<RBitField8::value_type> s_mask;
|
||||
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
|
||||
@ -238,7 +248,7 @@ class Columnar {
|
||||
valid = RBitField8(s_mask);
|
||||
|
||||
if (s_mask.data()) {
|
||||
CHECK_EQ(n_bits, size)
|
||||
CHECK_EQ(n_bits, num_rows)
|
||||
<< "Shape of bit mask doesn't match data shape. "
|
||||
<< "XGBoost doesn't support internal broadcasting.";
|
||||
}
|
||||
@ -271,7 +281,7 @@ class Columnar {
|
||||
} else if (type[1] == 'u' && type[2] == '8') {
|
||||
return;
|
||||
} else {
|
||||
LOG(FATAL) << ColumnarErrors::UnSupportedType(type);
|
||||
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(type);
|
||||
return;
|
||||
}
|
||||
}
|
||||
@ -304,10 +314,11 @@ class Columnar {
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
int32_t size;
|
||||
int32_t num_rows;
|
||||
int32_t num_cols;
|
||||
void* data;
|
||||
char type[3];
|
||||
};
|
||||
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_COLUMNAR_H_
|
||||
#endif // XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
@ -7,14 +7,14 @@
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "columnar.h"
|
||||
#include "array_interface.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "device_adapter.cuh"
|
||||
#include "simple_dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<float>* out) {
|
||||
void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {
|
||||
auto SetDeviceToPtr = [](void* ptr) {
|
||||
cudaPointerAttributes attr;
|
||||
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
|
||||
@ -22,43 +22,42 @@ void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<fl
|
||||
dh::safe_cuda(cudaSetDevice(ptr_device));
|
||||
return ptr_device;
|
||||
};
|
||||
Columnar foreign_column(column);
|
||||
auto ptr_device = SetDeviceToPtr(foreign_column.data);
|
||||
auto ptr_device = SetDeviceToPtr(column.data);
|
||||
|
||||
out->SetDevice(ptr_device);
|
||||
out->Resize(foreign_column.size);
|
||||
out->Resize(column.num_rows);
|
||||
|
||||
auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
|
||||
|
||||
dh::LaunchN(ptr_device, foreign_column.size, [=] __device__(size_t idx) {
|
||||
p_dst[idx] = foreign_column.GetElement(idx);
|
||||
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
|
||||
p_dst[idx] = column.GetElement(idx);
|
||||
});
|
||||
}
|
||||
|
||||
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
|
||||
auto const& j_arr = get<Array>(j_interface);
|
||||
CHECK_EQ(j_arr.size(), 1) << "MetaInfo: " << c_key << ". " << ColumnarErrors::Dimension(1);;
|
||||
auto const& j_arr_obj = get<Object const>(j_arr[0]);
|
||||
std::string key {c_key};
|
||||
ArrayInterfaceHandler::Validate(j_arr_obj);
|
||||
if (j_arr_obj.find("mask") != j_arr_obj.cend()) {
|
||||
LOG(FATAL) << "Meta info " << key << " should be dense, found validity mask";
|
||||
}
|
||||
auto const& typestr = get<String const>(j_arr_obj.at("typestr"));
|
||||
CHECK_EQ(j_arr.size(), 1)
|
||||
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
|
||||
ArrayInterface array_interface(get<Object const>(j_arr[0]));
|
||||
std::string key{c_key};
|
||||
CHECK(!array_interface.valid.Data())
|
||||
<< "Meta info " << key << " should be dense, found validity mask";
|
||||
CHECK_EQ(array_interface.num_cols, 1)
|
||||
<< "Meta info should be a single column.";
|
||||
|
||||
if (key == "label") {
|
||||
CopyInfoImpl(j_arr_obj, &labels_);
|
||||
CopyInfoImpl(array_interface, &labels_);
|
||||
} else if (key == "weight") {
|
||||
CopyInfoImpl(j_arr_obj, &weights_);
|
||||
CopyInfoImpl(array_interface, &weights_);
|
||||
} else if (key == "base_margin") {
|
||||
CopyInfoImpl(j_arr_obj, &base_margin_);
|
||||
CopyInfoImpl(array_interface, &base_margin_);
|
||||
} else if (key == "group") {
|
||||
// Ranking is not performed on device.
|
||||
auto s_data = ArrayInterfaceHandler::ExtractData<uint32_t>(j_arr_obj);
|
||||
thrust::device_ptr<uint32_t> p_src {s_data.data()};
|
||||
thrust::device_ptr<uint32_t> p_src{
|
||||
reinterpret_cast<uint32_t*>(array_interface.data)};
|
||||
|
||||
auto length = s_data.size();
|
||||
auto length = array_interface.num_rows;
|
||||
group_ptr_.resize(length + 1);
|
||||
group_ptr_[0] = 0;
|
||||
thrust::copy(p_src, p_src + length, group_ptr_.begin() + 1);
|
||||
@ -82,4 +81,7 @@ DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
|
||||
template DMatrix* DMatrix::Create<data::CudfAdapter>(
|
||||
data::CudfAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
template DMatrix* DMatrix::Create<data::CupyAdapter>(
|
||||
data::CupyAdapter* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix, size_t page_size);
|
||||
} // namespace xgboost
|
||||
|
||||
@ -7,9 +7,9 @@
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "columnar.h"
|
||||
#include "adapter.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "adapter.h"
|
||||
#include "array_interface.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
@ -17,12 +17,13 @@ namespace data {
|
||||
class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
CudfAdapterBatch() = default;
|
||||
CudfAdapterBatch(common::Span<Columnar> columns,
|
||||
CudfAdapterBatch(common::Span<ArrayInterface> columns,
|
||||
common::Span<size_t> column_ptr, size_t num_elements)
|
||||
: columns_(columns),column_ptr_(column_ptr), num_elements(num_elements) {}
|
||||
size_t Size()const { return num_elements; }
|
||||
__device__ COOTuple GetElement(size_t idx)const
|
||||
{
|
||||
: columns_(columns),
|
||||
column_ptr_(column_ptr),
|
||||
num_elements(num_elements) {}
|
||||
size_t Size() const { return num_elements; }
|
||||
__device__ COOTuple GetElement(size_t idx) const {
|
||||
size_t column_idx =
|
||||
dh::UpperBound(column_ptr_.data(), column_ptr_.size(), idx) - 1;
|
||||
auto& column = columns_[column_idx];
|
||||
@ -34,22 +35,23 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
|
||||
}
|
||||
|
||||
private:
|
||||
common::Span<Columnar> columns_;
|
||||
common::Span<ArrayInterface> columns_;
|
||||
common::Span<size_t> column_ptr_;
|
||||
size_t num_elements;
|
||||
};
|
||||
|
||||
/*!
|
||||
* Please be careful that, in official specification, the only three required fields are
|
||||
* `shape', `version' and `typestr'. Any other is optional, including `data'. But here
|
||||
* we have one additional requirements for input data:
|
||||
* Please be careful that, in official specification, the only three required
|
||||
* fields are `shape', `version' and `typestr'. Any other is optional,
|
||||
* including `data'. But here we have one additional requirements for input
|
||||
* data:
|
||||
*
|
||||
* - `data' field is required, passing in an empty dataset is not accepted, as most (if
|
||||
* not all) of our algorithms don't have test for empty dataset. An error is better
|
||||
* than a crash.
|
||||
* - `data' field is required, passing in an empty dataset is not accepted, as
|
||||
* most (if not all) of our algorithms don't have test for empty dataset. An
|
||||
* error is better than a crash.
|
||||
*
|
||||
* What if invalid value from dataframe is 0 but I specify missing=NaN in XGBoost? Since
|
||||
* validity mask is ignored, all 0s are preserved in XGBoost.
|
||||
* What if invalid value from dataframe is 0 but I specify missing=NaN in
|
||||
* XGBoost? Since validity mask is ignored, all 0s are preserved in XGBoost.
|
||||
*
|
||||
* FIXME(trivialfis): Put above into document after we have a consistent way for
|
||||
* processing input data.
|
||||
@ -96,23 +98,23 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
|
||||
CHECK_GT(n_columns, 0) << "Number of columns must not equal to 0.";
|
||||
|
||||
auto const& typestr = get<String const>(json_columns[0]["typestr"]);
|
||||
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
|
||||
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
|
||||
std::vector<Columnar> columns;
|
||||
CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
|
||||
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
|
||||
std::vector<ArrayInterface> columns;
|
||||
std::vector<size_t> column_ptr({0});
|
||||
auto first_column = Columnar(get<Object const>(json_columns[0]));
|
||||
auto first_column = ArrayInterface(get<Object const>(json_columns[0]));
|
||||
device_idx_ = dh::CudaGetPointerDevice(first_column.data);
|
||||
CHECK_NE(device_idx_, -1);
|
||||
dh::safe_cuda(cudaSetDevice(device_idx_));
|
||||
num_rows_ = first_column.size;
|
||||
num_rows_ = first_column.num_rows;
|
||||
for (auto& json_col : json_columns) {
|
||||
auto column = Columnar(get<Object const>(json_col));
|
||||
auto column = ArrayInterface(get<Object const>(json_col));
|
||||
columns.push_back(column);
|
||||
column_ptr.emplace_back(column_ptr.back() + column.size);
|
||||
num_rows_ = std::max(num_rows_, size_t(column.size));
|
||||
column_ptr.emplace_back(column_ptr.back() + column.num_rows);
|
||||
num_rows_ = std::max(num_rows_, size_t(column.num_rows));
|
||||
CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
|
||||
<< "All columns should use the same device.";
|
||||
CHECK_EQ(num_rows_, column.size)
|
||||
CHECK_EQ(num_rows_, column.num_rows)
|
||||
<< "All columns should have same number of rows.";
|
||||
}
|
||||
columns_ = columns;
|
||||
@ -124,19 +126,65 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
|
||||
|
||||
size_t NumRows() const { return num_rows_; }
|
||||
size_t NumColumns() const { return columns_.size(); }
|
||||
size_t DeviceIdx()const {
|
||||
return device_idx_;
|
||||
}
|
||||
size_t DeviceIdx() const { return device_idx_; }
|
||||
|
||||
// Cudf is column major
|
||||
bool IsRowMajor() { return false; }
|
||||
|
||||
private:
|
||||
CudfAdapterBatch batch;
|
||||
dh::device_vector<Columnar> columns_;
|
||||
dh::device_vector<ArrayInterface> columns_;
|
||||
dh::device_vector<size_t> column_ptr_; // Exclusive scan of column sizes
|
||||
size_t num_rows_{0};
|
||||
int device_idx_;
|
||||
};
|
||||
|
||||
class CupyAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
CupyAdapterBatch() = default;
|
||||
CupyAdapterBatch(ArrayInterface array_interface)
|
||||
: array_interface_(array_interface) {}
|
||||
size_t Size() const {
|
||||
return array_interface_.num_rows * array_interface_.num_cols;
|
||||
}
|
||||
__device__ COOTuple GetElement(size_t idx) const {
|
||||
size_t column_idx = idx % array_interface_.num_cols;
|
||||
size_t row_idx = idx / array_interface_.num_cols;
|
||||
float value = array_interface_.valid.Data() == nullptr ||
|
||||
array_interface_.valid.Check(row_idx)
|
||||
? array_interface_.GetElement(idx)
|
||||
: std::numeric_limits<float>::quiet_NaN();
|
||||
return COOTuple(row_idx, column_idx, value);
|
||||
}
|
||||
|
||||
private:
|
||||
ArrayInterface array_interface_;
|
||||
};
|
||||
|
||||
class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
|
||||
public:
|
||||
explicit CupyAdapter(std::string cuda_interface_str) {
|
||||
Json json_array_interface =
|
||||
Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
|
||||
array_interface = ArrayInterface(get<Object const>(json_array_interface));
|
||||
device_idx_ = dh::CudaGetPointerDevice(array_interface.data);
|
||||
CHECK_NE(device_idx_, -1);
|
||||
batch = CupyAdapterBatch(array_interface);
|
||||
}
|
||||
const CupyAdapterBatch& Value() const override { return batch; }
|
||||
|
||||
size_t NumRows() const { return array_interface.num_rows; }
|
||||
size_t NumColumns() const { return array_interface.num_cols; }
|
||||
size_t DeviceIdx() const { return device_idx_; }
|
||||
|
||||
bool IsRowMajor() { return true; }
|
||||
|
||||
private:
|
||||
ArrayInterface array_interface;
|
||||
CupyAdapterBatch batch;
|
||||
int device_idx_;
|
||||
};
|
||||
|
||||
}; // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_DEVICE_ADAPTER_H_
|
||||
|
||||
@ -7,7 +7,6 @@
|
||||
#include <xgboost/json.h>
|
||||
|
||||
#include "simple_csr_source.h"
|
||||
#include "columnar.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
@ -78,6 +78,35 @@ void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
}
|
||||
}
|
||||
|
||||
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
||||
explicit IsValidFunctor(float missing) : missing(missing) {}
|
||||
|
||||
float missing;
|
||||
__device__ bool operator()(const Entry& x) const {
|
||||
return IsValid(x.fvalue, missing);
|
||||
}
|
||||
};
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
template <typename AdapterT>
|
||||
void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data,
|
||||
int device_idx, float missing,
|
||||
common::Span<size_t> row_ptr) {
|
||||
auto& batch = adapter->Value();
|
||||
auto transform_f = [=] __device__(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
return Entry(e.column_idx, e.value);
|
||||
}; // NOLINT
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
thrust::transform_iterator<decltype(transform_f), decltype(counting), Entry>
|
||||
transform_iter(counting, transform_f);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::copy_if(
|
||||
thrust::cuda::par(alloc), transform_iter, transform_iter + batch.Size(),
|
||||
thrust::device_pointer_cast(data.data()), IsValidFunctor(missing));
|
||||
}
|
||||
|
||||
// Does not currently support metainfo as no on-device data source contains this
|
||||
// Current implementation assumes a single batch. More batches can
|
||||
// be supported in future. Does not currently support inferring row/column size
|
||||
@ -102,11 +131,14 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
mat.info.num_nonzero_ = mat.page_.offset.HostVector().back();
|
||||
mat.page_.data.Resize(mat.info.num_nonzero_);
|
||||
if (adapter->IsRowMajor()) {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
CopyDataRowMajor(adapter, mat.page_.data.DeviceSpan(),
|
||||
adapter->DeviceIdx(), missing, s_offset);
|
||||
} else {
|
||||
CopyDataColumnMajor(adapter, mat.page_.data.DeviceSpan(),
|
||||
adapter->DeviceIdx(), missing, s_offset);
|
||||
}
|
||||
// Sync
|
||||
mat.page_.data.HostVector();
|
||||
|
||||
mat.info.num_col_ = adapter->NumColumns();
|
||||
mat.info.num_row_ = adapter->NumRows();
|
||||
@ -116,5 +148,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
|
||||
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
template SimpleDMatrix::SimpleDMatrix(CupyAdapter* adapter, float missing,
|
||||
int nthread);
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@ -18,7 +18,7 @@ ENV PATH=/opt/python/bin:$PATH
|
||||
# Create new Conda environment with cuDF and dask
|
||||
RUN \
|
||||
conda create -n cudf_test -c rapidsai -c nvidia -c numba -c conda-forge -c anaconda \
|
||||
cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask dask-cuda
|
||||
cudf=0.9 python=3.7 anaconda::cudatoolkit=$CUDA_VERSION dask dask-cuda cupy
|
||||
|
||||
# Install other Python packages
|
||||
RUN \
|
||||
|
||||
@ -39,7 +39,7 @@ case "$suite" in
|
||||
|
||||
cudf)
|
||||
source activate cudf_test
|
||||
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py
|
||||
pytest -v -s --fulltrace -m "not mgpu" tests/python-gpu/test_from_columnar.py tests/python-gpu/test_from_cupy.py
|
||||
;;
|
||||
|
||||
cpu)
|
||||
|
||||
@ -8,7 +8,6 @@
|
||||
#include "../../../src/common/bitfield.h"
|
||||
#include "../../../src/common/device_helpers.cuh"
|
||||
#include "../../../src/data/simple_csr_source.h"
|
||||
#include "../../../src/data/columnar.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -62,4 +61,24 @@ Json GenerateSparseColumn(std::string const& typestr, size_t kRows,
|
||||
column["typestr"] = String(typestr);
|
||||
return column;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Json Generate2dArrayInterface(int rows, int cols, std::string typestr,
|
||||
thrust::device_vector<T>* p_data) {
|
||||
auto& data = *p_data;
|
||||
thrust::sequence(data.begin(), data.end());
|
||||
|
||||
Json array_interface{Object()};
|
||||
std::vector<Json> shape = {Json(static_cast<Integer::Int>(rows)),
|
||||
Json(static_cast<Integer::Int>(cols))};
|
||||
array_interface["shape"] = Array(shape);
|
||||
std::vector<Json> j_data{
|
||||
Json(Integer(reinterpret_cast<Integer::Int>(data.data().get()))),
|
||||
Json(Boolean(false))};
|
||||
array_interface["data"] = j_data;
|
||||
array_interface["version"] = Integer(static_cast<Integer::Int>(1));
|
||||
array_interface["typestr"] = String(typestr);
|
||||
return array_interface;
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
@ -7,7 +7,7 @@
|
||||
#include "../helpers.h"
|
||||
#include <thrust/device_vector.h>
|
||||
#include "../../../src/data/device_adapter.cuh"
|
||||
#include "test_columnar.h"
|
||||
#include "test_array_interface.h"
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
void TestCudfAdapter()
|
||||
|
||||
@ -9,8 +9,7 @@
|
||||
namespace xgboost {
|
||||
|
||||
template <typename T>
|
||||
std::string PrepareData(std::string typestr, thrust::device_vector<T>* out) {
|
||||
constexpr size_t kRows = 16;
|
||||
std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, const size_t kRows=16) {
|
||||
out->resize(kRows);
|
||||
auto& d_data = *out;
|
||||
|
||||
@ -66,7 +65,15 @@ TEST(MetaInfo, FromInterface) {
|
||||
ASSERT_EQ(h_base_margin[i], d_data[i]);
|
||||
}
|
||||
|
||||
EXPECT_ANY_THROW({info.SetInfo("group", str.c_str());});
|
||||
thrust::device_vector<int> d_group_data;
|
||||
std::string group_str = PrepareData<int>("<i4", &d_group_data, 4);
|
||||
d_group_data[0] = 4;
|
||||
d_group_data[1] = 3;
|
||||
d_group_data[2] = 2;
|
||||
d_group_data[3] = 1;
|
||||
info.SetInfo("group", group_str.c_str());
|
||||
std::vector<bst_group_t> expected_group_ptr = {0, 4, 7, 9, 10};
|
||||
EXPECT_EQ(info.group_ptr_, expected_group_ptr);
|
||||
}
|
||||
|
||||
TEST(MetaInfo, Group) {
|
||||
@ -83,4 +90,4 @@ TEST(MetaInfo, Group) {
|
||||
ASSERT_EQ(h_group[i], d_data[i-1] + h_group[i-1]) << "i: " << i;
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost
|
||||
|
||||
@ -6,7 +6,8 @@
|
||||
#include <thrust/sequence.h>
|
||||
#include "../../../src/data/device_adapter.cuh"
|
||||
#include "../helpers.h"
|
||||
#include "test_columnar.h"
|
||||
#include "test_array_interface.h"
|
||||
#include "../../../src/data/array_interface.h"
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
@ -316,3 +317,55 @@ TEST(SimpleDMatrix, FromColumnarSparseBasic) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(SimpleDMatrix, FromCupy){
|
||||
int rows = 50;
|
||||
int cols = 10;
|
||||
thrust::device_vector< float> data(rows*cols);
|
||||
auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
|
||||
std::stringstream ss;
|
||||
Json::Dump(json_array_interface, &ss);
|
||||
std::string str = ss.str();
|
||||
data::CupyAdapter adapter(str);
|
||||
data::SimpleDMatrix dmat(&adapter, -1, 1);
|
||||
EXPECT_EQ(dmat.Info().num_col_, cols);
|
||||
EXPECT_EQ(dmat.Info().num_row_, rows);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, rows*cols);
|
||||
|
||||
for (auto& batch : dmat.GetBatches<SparsePage>()) {
|
||||
for (auto i = 0ull; i < batch.Size(); i++) {
|
||||
auto inst = batch[i];
|
||||
for (auto j = 0ull; j < inst.size(); j++) {
|
||||
EXPECT_EQ(inst[j].fvalue, i * cols + j);
|
||||
EXPECT_EQ(inst[j].index, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SimpleDMatrix, FromCupySparse){
|
||||
int rows = 2;
|
||||
int cols = 2;
|
||||
thrust::device_vector< float> data(rows*cols);
|
||||
auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
|
||||
data[1] = std::numeric_limits<float>::quiet_NaN();
|
||||
data[2] = std::numeric_limits<float>::quiet_NaN();
|
||||
std::stringstream ss;
|
||||
Json::Dump(json_array_interface, &ss);
|
||||
std::string str = ss.str();
|
||||
data::CupyAdapter adapter(str);
|
||||
data::SimpleDMatrix dmat(&adapter, -1, 1);
|
||||
EXPECT_EQ(dmat.Info().num_col_, cols);
|
||||
EXPECT_EQ(dmat.Info().num_row_, rows);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, rows * cols - 2);
|
||||
auto& batch = *dmat.GetBatches<SparsePage>().begin();
|
||||
auto inst0 = batch[0];
|
||||
auto inst1 = batch[1];
|
||||
EXPECT_EQ(batch[0].size(), 1);
|
||||
EXPECT_EQ(batch[1].size(), 1);
|
||||
EXPECT_EQ(batch[0][0].fvalue, 0.0f);
|
||||
EXPECT_EQ(batch[0][0].index, 0);
|
||||
EXPECT_EQ(batch[1][0].fvalue, 3.0f);
|
||||
EXPECT_EQ(batch[1][0].index, 1);
|
||||
}
|
||||
|
||||
@ -2,6 +2,7 @@ import numpy as np
|
||||
import xgboost as xgb
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
|
||||
@ -86,3 +87,64 @@ Arrow specification.'''
|
||||
'x': cudf.Series([True, False, True, True, True])})
|
||||
with pytest.raises(Exception):
|
||||
dtrain = xgb.DMatrix(X_boolean, label=y_boolean)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_cudf_training(self):
|
||||
from cudf import DataFrame as df
|
||||
import pandas as pd
|
||||
X = pd.DataFrame(np.random.randn(50, 10))
|
||||
y = pd.DataFrame(np.random.randn(50))
|
||||
weights = np.random.random(50)
|
||||
cudf_weights = df.from_pandas(pd.DataFrame(weights))
|
||||
base_margin = np.random.random(50)
|
||||
cudf_base_margin = df.from_pandas(pd.DataFrame(base_margin))
|
||||
|
||||
evals_result_cudf = {}
|
||||
dtrain_cudf = xgb.DMatrix(df.from_pandas(X), df.from_pandas(y), weight=cudf_weights,
|
||||
base_margin=cudf_base_margin)
|
||||
xgb.train({'gpu_id': 0}, dtrain_cudf, evals=[(dtrain_cudf, "train")],
|
||||
evals_result=evals_result_cudf)
|
||||
evals_result_np = {}
|
||||
dtrain_np = xgb.DMatrix(X, y, weight=weights, base_margin=base_margin)
|
||||
xgb.train({}, dtrain_np, evals=[(dtrain_np, "train")],
|
||||
evals_result=evals_result_np)
|
||||
assert np.array_equal(evals_result_cudf["train"]["rmse"], evals_result_np["train"]["rmse"])
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_cudf_metainfo(self):
|
||||
from cudf import DataFrame as df
|
||||
import pandas as pd
|
||||
n = 100
|
||||
X = np.random.random((n, 2))
|
||||
dmat_cudf = xgb.DMatrix(X)
|
||||
dmat = xgb.DMatrix(X)
|
||||
floats = np.random.random(n)
|
||||
uints = np.array([4, 2, 8]).astype("uint32")
|
||||
cudf_floats = df.from_pandas(pd.DataFrame(floats))
|
||||
cudf_uints = df.from_pandas(pd.DataFrame(uints))
|
||||
dmat.set_float_info('weight', floats)
|
||||
dmat.set_float_info('label', floats)
|
||||
dmat.set_float_info('base_margin', floats)
|
||||
dmat.set_uint_info('group', uints)
|
||||
dmat_cudf.set_interface_info('weight', cudf_floats)
|
||||
dmat_cudf.set_interface_info('label', cudf_floats)
|
||||
dmat_cudf.set_interface_info('base_margin', cudf_floats)
|
||||
dmat_cudf.set_interface_info('group', cudf_uints)
|
||||
|
||||
# Test setting info with cudf DataFrame
|
||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||
dmat_cudf.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
||||
|
||||
# Test setting info with cudf Series
|
||||
dmat_cudf.set_interface_info('weight', cudf_floats[cudf_floats.columns[0]])
|
||||
dmat_cudf.set_interface_info('label', cudf_floats[cudf_floats.columns[0]])
|
||||
dmat_cudf.set_interface_info('base_margin', cudf_floats[cudf_floats.columns[0]])
|
||||
dmat_cudf.set_interface_info('group', cudf_uints[cudf_uints.columns[0]])
|
||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cudf.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cudf.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||
dmat_cudf.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cudf.get_uint_info('group_ptr'))
|
||||
|
||||
97
tests/python-gpu/test_from_cupy.py
Normal file
97
tests/python-gpu/test_from_cupy.py
Normal file
@ -0,0 +1,97 @@
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
|
||||
|
||||
def dmatrix_from_cupy(input_type, missing=np.NAN):
|
||||
'''Test constructing DMatrix from cupy'''
|
||||
import cupy as cp
|
||||
|
||||
kRows = 80
|
||||
kCols = 3
|
||||
|
||||
np_X = np.random.randn(kRows, kCols).astype(dtype=input_type)
|
||||
X = cp.array(np_X)
|
||||
X[5, 0] = missing
|
||||
X[3, 1] = missing
|
||||
y = cp.random.randn(kRows).astype(dtype=input_type)
|
||||
dtrain = xgb.DMatrix(X, missing=missing, label=y)
|
||||
assert dtrain.num_col() == kCols
|
||||
assert dtrain.num_row() == kRows
|
||||
return dtrain
|
||||
|
||||
|
||||
class TestFromArrayInterface:
|
||||
'''Tests for constructing DMatrix from data structure conforming Apache
|
||||
Arrow specification.'''
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_from_cupy(self):
|
||||
'''Test constructing DMatrix from cupy'''
|
||||
import cupy as cp
|
||||
dmatrix_from_cupy(np.float32, np.NAN)
|
||||
dmatrix_from_cupy(np.float64, np.NAN)
|
||||
|
||||
dmatrix_from_cupy(np.uint8, 2)
|
||||
dmatrix_from_cupy(np.uint32, 3)
|
||||
dmatrix_from_cupy(np.uint64, 4)
|
||||
|
||||
dmatrix_from_cupy(np.int8, 2)
|
||||
dmatrix_from_cupy(np.int32, -2)
|
||||
dmatrix_from_cupy(np.int64, -3)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
X = cp.random.randn(2, 2, dtype="float32")
|
||||
dtrain = xgb.DMatrix(X, label=X)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_cupy_training(self):
|
||||
import cupy as cp
|
||||
X = cp.random.randn(50, 10, dtype="float32")
|
||||
y = cp.random.randn(50, dtype="float32")
|
||||
weights = np.random.random(50)
|
||||
cupy_weights = cp.array(weights)
|
||||
base_margin = np.random.random(50)
|
||||
cupy_base_margin = cp.array(base_margin)
|
||||
|
||||
evals_result_cupy = {}
|
||||
dtrain_cp = xgb.DMatrix(X, y, weight=cupy_weights, base_margin=cupy_base_margin)
|
||||
xgb.train({'gpu_id': 0}, dtrain_cp, evals=[(dtrain_cp, "train")],
|
||||
evals_result=evals_result_cupy)
|
||||
evals_result_np = {}
|
||||
dtrain_np = xgb.DMatrix(cp.asnumpy(X), cp.asnumpy(y), weight=weights,
|
||||
base_margin=base_margin)
|
||||
xgb.train({'gpu_id': 0}, dtrain_np, evals=[(dtrain_np, "train")],
|
||||
evals_result=evals_result_np)
|
||||
assert np.array_equal(evals_result_cupy["train"]["rmse"], evals_result_np["train"]["rmse"])
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_cupy_metainfo(self):
|
||||
import cupy as cp
|
||||
n = 100
|
||||
X = np.random.random((n, 2))
|
||||
dmat_cupy = xgb.DMatrix(X)
|
||||
dmat = xgb.DMatrix(X)
|
||||
floats = np.random.random(n)
|
||||
uints = np.array([4, 2, 8]).astype("uint32")
|
||||
cupy_floats = cp.array(floats)
|
||||
cupy_uints = cp.array(uints)
|
||||
dmat.set_float_info('weight', floats)
|
||||
dmat.set_float_info('label', floats)
|
||||
dmat.set_float_info('base_margin', floats)
|
||||
dmat.set_uint_info('group', uints)
|
||||
dmat_cupy.set_interface_info('weight', cupy_floats)
|
||||
dmat_cupy.set_interface_info('label', cupy_floats)
|
||||
dmat_cupy.set_interface_info('base_margin', cupy_floats)
|
||||
dmat_cupy.set_interface_info('group', cupy_uints)
|
||||
|
||||
# Test setting info with cupy
|
||||
assert np.array_equal(dmat.get_float_info('weight'), dmat_cupy.get_float_info('weight'))
|
||||
assert np.array_equal(dmat.get_float_info('label'), dmat_cupy.get_float_info('label'))
|
||||
assert np.array_equal(dmat.get_float_info('base_margin'),
|
||||
dmat_cupy.get_float_info('base_margin'))
|
||||
assert np.array_equal(dmat.get_uint_info('group_ptr'), dmat_cupy.get_uint_info('group_ptr'))
|
||||
@ -48,6 +48,15 @@ def no_cudf():
|
||||
'reason': 'CUDF is not installed'}
|
||||
|
||||
|
||||
def no_cupy():
|
||||
reason = 'cupy is not installed.'
|
||||
try:
|
||||
import cupy as _ # noqa
|
||||
return {'condition': False, 'reason': reason}
|
||||
except ImportError:
|
||||
return {'condition': True, 'reason': reason}
|
||||
|
||||
|
||||
def no_dask_cudf():
|
||||
reason = 'dask_cudf is not installed.'
|
||||
try:
|
||||
|
||||
@ -16,10 +16,9 @@ if [ ${TASK} == "python_test" ]; then
|
||||
echo "-------------------------------"
|
||||
conda activate python3
|
||||
python --version
|
||||
conda install numpy scipy pandas matplotlib scikit-learn
|
||||
conda install numpy scipy pandas matplotlib scikit-learn dask
|
||||
|
||||
python -m pip install graphviz pytest pytest-cov codecov
|
||||
python -m pip install dask distributed dask[dataframe]
|
||||
python -m pip install datatable
|
||||
python -m pytest -v --fulltrace -s tests/python --cov=python-package/xgboost || exit -1
|
||||
codecov
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user