Complete cudf support. (#4850)
* Handle missing values.
* Accept all floating point and integer types.
* Move to the cuDF 0.9 API.
* Remove the requirement on `null_count`.
* Support arbitrary column types.
This commit is contained in:
parent 125bcec62e
commit 5374f52531
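
For context, the end result of this change is that a cuDF DataFrame with nulls and mixed numeric dtypes can be handed to `DMatrix` directly. A minimal usage sketch (assumes a CUDA-capable environment with cudf installed; column names and values are illustrative only):

    import numpy as np
    import pandas as pd
    import cudf
    import xgboost as xgb

    # Two float columns and one integer column; row 3 of column 'b' is missing.
    df = cudf.from_pandas(pd.DataFrame({
        'a': np.arange(8, dtype=np.float32),
        'b': [0.0, 1.0, 2.0, np.nan, 4.0, 5.0, 6.0, 7.0],
        'c': np.arange(8, dtype=np.int32),
    }))
    label = cudf.from_pandas(pd.DataFrame(np.random.randn(8).astype(np.float32)))

    # Nulls/NaNs are dropped; alternatively pass missing=<value> to treat that
    # value as missing and ignore the validity mask.
    dtrain = xgb.DMatrix(df, label=label)
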
@ -7,7 +7,6 @@ import collections
# pylint: disable=no-name-in-module,import-error
from collections.abc import Mapping # Python 3
# pylint: enable=no-name-in-module,import-error
import math
import ctypes
import os
import re
@ -235,28 +234,13 @@ def _extract_interface_from_cudf(df, is_info):
'columnar format. For other libraries please ' +
'refer to specific API.')

def get_interface(obj):
return obj.mem.__cuda_array_interface__

array_interfaces = []
for col in df.columns:
data = df[col].data
array_interfaces.append(get_interface(data))

validity_masks = []
for col in df.columns:
if df[col].has_null_mask:
mask_interface = get_interface(df[col].nullmask)
mask_interface['null_count'] = df[col].null_count
validity_masks.append(mask_interface)
else:
validity_masks.append(False)

for i in range(len(df.columns)):
col_interface = array_interfaces[i]
mask_interface = validity_masks[i]
if mask_interface is not False:
col_interface['mask'] = mask_interface
data = df[col]
interface = data.__cuda_array_interface__
if data.has_null_mask:
interface['mask'] = interface['mask'].__cuda_array_interface__
array_interfaces.append(interface)

if is_info:
array_interfaces = array_interfaces[0]
@ -369,22 +353,6 @@ def _maybe_dt_array(array):
return array


def _check_data(data, missing):
'''The missing value applies only to np.ndarray.'''
is_invalid = (not isinstance(data, np.ndarray)) and (missing is not None)
is_invalid = is_invalid and not math.isnan(missing)
if is_invalid:
raise ValueError(
'missing value only applies to dense input, ' +
'e.g. `numpy.ndarray`.' +
' For a possibly sparse data type: ' + str(type(data)) +
' please remove missing values or set it to nan.' +
' Current missing value is set to: ' + str(missing))
if isinstance(data, list):
warnings.warn('Initializing DMatrix from List is deprecated.',
DeprecationWarning)


class DMatrix(object):
"""Data Matrix used in XGBoost.

@ -443,7 +411,8 @@ class DMatrix(object):
self._feature_types = feature_types
return

_check_data(data, missing)
if isinstance(data, list):
raise TypeError('Input data can not be a list.')

data, feature_names, feature_types = _maybe_pandas_data(data,
feature_names,
@ -472,7 +441,7 @@ class DMatrix(object):
elif isinstance(data, DataTable):
self._init_from_dt(data, nthread)
elif _use_columnar_initializer(data):
self._init_from_columnar(data)
self._init_from_columnar(data, missing)
else:
try:
csr = scipy.sparse.csr_matrix(data)
@ -599,15 +568,18 @@ class DMatrix(object):
nthread))
self.handle = handle

def _init_from_columnar(self, df):
def _init_from_columnar(self, df, missing):
'''Initialize DMatrix from columnar memory format.

'''
interfaces = _extract_interface_from_cudf(df, False)
handle = ctypes.c_void_p()
has_missing = missing is not None
missing = missing if has_missing else np.nan
_check_call(
_LIB.XGDMatrixCreateFromArrayInterfaces(interfaces,
ctypes.byref(handle)))
_LIB.XGDMatrixCreateFromArrayInterfaces(
interfaces, ctypes.c_int32(has_missing),
ctypes.c_float(missing), ctypes.byref(handle)))
self.handle = handle

def __del__(self):
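
For reference, the missing-value plumbing added to `_init_from_columnar` above boils down to the following mapping from the Python `missing` argument to the two new C API arguments (a sketch only; the standalone helper name is hypothetical):

    import ctypes
    import numpy as np

    def missing_to_c_args(missing):
        # None -> (0, NaN): the columnar source falls back to validity-mask/NaN handling.
        # A concrete value -> (1, value): entries equal to the value are dropped and
        # the validity mask is ignored.
        has_missing = missing is not None
        value = missing if has_missing else np.nan
        return ctypes.c_int32(has_missing), ctypes.c_float(value)
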
@ -190,11 +190,11 @@ int XGDMatrixCreateFromDataIter(
}

XGB_DLL int XGDMatrixCreateFromArrayInterfaces(
char const* c_json_strs, DMatrixHandle* out) {
char const* c_json_strs, bst_int has_missing, bst_float missing, DMatrixHandle* out) {
API_BEGIN();
std::string json_str {c_json_strs};
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(json_str);
source->CopyFrom(json_str, has_missing, missing);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}

@ -345,6 +345,8 @@ struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
// Replacement of allocator with custom backend should occur here
template <typename T>
using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocatorImpl<T>;
/*! Be careful that the initialization constructor is a no-op, which means calling
 * `vec.resize(n, 1)` won't initialize the memory region to 1. */
template <typename T>
using XGBCachingDeviceAllocator = detail::XGBCachingDeviceAllocatorImpl<T>;
/** \brief Specialisation of thrust device vector using custom allocator. */

@ -9,11 +9,11 @@

#include <xgboost/base.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
#include <vector>
#include <cmath>
#include <algorithm>
#include <utility>

namespace xgboost {
namespace common {
@ -26,6 +26,23 @@ XGBOOST_DEVICE inline float Sigmoid(float x) {
return 1.0f / (1.0f + expf(-x));
}

/*!
 * \brief Equality test for both integer and floating point.
 */
template <typename T, typename U>
XGBOOST_DEVICE constexpr bool CloseTo(T a, U b) {
using Casted =
typename std::conditional<
std::is_floating_point<T>::value || std::is_floating_point<U>::value,
double,
typename std::conditional<
std::is_signed<T>::value || std::is_signed<U>::value,
int64_t,
uint64_t>::type>::type;
return std::is_floating_point<Casted>::value ?
std::abs(static_cast<Casted>(a) - static_cast<Casted>(b)) < 1e-6 : a == b;
}

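The `CloseTo` helper above promotes mixed integer/floating-point arguments to a common type and only applies a tolerance when floating point is involved. Roughly, in Python terms (illustrative sketch, not part of the change):

    def close_to(a, b, tol=1e-6):
        # Exact comparison for two integers; absolute tolerance once either
        # side is floating point, mirroring the C++ template above.
        if isinstance(a, float) or isinstance(b, float):
            return abs(float(a) - float(b)) < tol
        return a == b
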
/*!
 * \brief Do inplace softmax transformation on start to end
 *
@ -119,14 +136,34 @@ inline static bool CmpSecond(const std::pair<float, unsigned> &a,
// check nan
bool CheckNAN(double v);
#else

// Redefined here to work around a VC bug that doesn't support overloading for integer
// types.
template <typename T>
inline bool CheckNAN(T v) {
#ifdef _MSC_VER
return (_isnan(v) != 0);
#else
return std::isnan(v);
#endif  // _MSC_VER
XGBOOST_DEVICE typename std::enable_if<
std::numeric_limits<T>::is_integer, bool>::type
CheckNAN(T) {
return false;
}

XGBOOST_DEVICE bool inline CheckNAN(float x) {
#if (defined(_WIN32) || defined(__CUDA_ARCH__)) && \
!defined(__MINGW64__) && !defined(__MINGW32__) && !defined(__CYGWIN__)
return isnan(x);
#else
return std::isnan(x);
#endif  // has c++11 std::isnan
}

XGBOOST_DEVICE bool inline CheckNAN(double x) {
#if (defined(_WIN32) || defined(__CUDA_ARCH__)) && \
!defined(__MINGW64__) && !defined(__MINGW32__) && !defined(__CYGWIN__)
return isnan(x);
#else
return std::isnan(x);
#endif  // has c++11 std::isnan
}

#endif  // XGBOOST_STRICT_R_MODE_

// GPU version is not uploaded in CRAN anyway.
@ -12,18 +12,20 @@

#include "xgboost/data.h"
#include "xgboost/json.h"
#include "xgboost/logging.h"
#include "../common/span.h"
#include "../common/bitfield.h"

namespace xgboost {
// A view over __array_interface__
template <typename T>
struct Columnar {
using mask_type = unsigned char;
using index_type = int32_t;

common::Span<float> data;
common::Span<T> data;
RBitField8 valid;
int32_t size;
int32_t null_count;
};

// Common errors in parsing columnar format.
@ -49,13 +51,7 @@ struct ColumnarErrors {
static char const* Version() {
return "Only version 1 of __cuda_array_interface__ is being supported.";
}
static char const* toFloat() {
return "Please convert the input into float32 first.";
}
static char const* toUInt() {
return "Please convert the Group into unsigned 32 bit integers first.";
}
static char const* ofType(std::string type) {
static char const* ofType(std::string const& type) {
static std::string str;
str.clear();
str += " should be of ";
@ -63,10 +59,43 @@ struct ColumnarErrors {
str += " type.";
return str.c_str();
}
static std::string UnknownTypeStr(std::string const& typestr) {
return "typestr from array interface: " + typestr + " is not supported.";
}
};

// TODO(trivialfis): Abstract this into a class that accept a json
// object and turn it into an array (for cupy and numba).
class ArrayInterfaceHandler {
public:
template <typename T>
static constexpr char TypeChar() {
return
(std::is_floating_point<T>::value ? 'f' :
(std::is_integral<T>::value ?
(std::is_signed<T>::value ? 'i' : 'u') : '\0'));
}

static std::string TypeStr(char c) {
switch (c) {
case 't':
return "Bit field";
case 'b':
return "Boolean";
case 'i':
return "Integer";
case 'u':
return "Unsigned integer";
case 'f':
return "Floating point";
default:
LOG(FATAL) << "Invalid type code: " << c << " in typestr of input array interface.";
return "";
}
}

template <typename PtrType>
PtrType GetPtrFromArrayData(std::map<std::string, Json> const& obj) {
static PtrType GetPtrFromArrayData(std::map<std::string, Json> const& obj) {
if (obj.find("data") == obj.cend()) {
LOG(FATAL) << "Empty data passed in.";
}
@ -78,5 +107,128 @@ PtrType GetPtrFromArrayData(std::map<std::string, Json> const& obj) {
return p_data;
}

static void Validate(std::map<std::string, Json> const& array) {
if (array.find("version") == array.cend()) {
LOG(FATAL) << "Missing version field for array interface";
}
auto version = get<Integer const>(array.at("version"));
CHECK_EQ(version, 1) << ColumnarErrors::Version();

if (array.find("typestr") == array.cend()) {
LOG(FATAL) << "Missing typestr field for array interface";
}
auto typestr = get<String const>(array.at("typestr"));
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();

if (array.find("shape") == array.cend()) {
LOG(FATAL) << "Missing shape field for array interface";
}
if (array.find("data") == array.cend()) {
LOG(FATAL) << "Missing data field for array interface";
}
}

// Find null mask (validity mask) field
// Mask object is also an array interface, but with different requirements.
static void ExtractMask(std::map<std::string, Json> const& column,
common::Span<RBitField8::value_type>* p_out) {
auto& s_mask = *p_out;
if (column.find("mask") != column.cend()) {
auto const& j_mask = get<Object const>(column.at("mask"));
Validate(j_mask);

auto p_mask = GetPtrFromArrayData<RBitField8::value_type*>(j_mask);

auto j_shape = get<Array const>(j_mask.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(j_shape.front()) % 8, 0) <<
"Length of validity mask must be a multiple of 8 bytes.";
int64_t size = get<Integer>(j_shape.at(0)) *
sizeof(unsigned char) / sizeof(RBitField8::value_type);
auto typestr = get<String const>(j_mask.at("typestr"));

if (typestr.at(1) == 't') {
CHECK_EQ(typestr.at(2), '1') << "There can be only 1 bit in each entry of bitfield.";
} else if (typestr.at(1) == 'i') {
CHECK_EQ(typestr.at(2), '1') << "mask with integer type should be of 1 byte per integer.";
} else {
LOG(FATAL) << "mask must be of integer type or bit field type.";
}

// For now this is just 1
int64_t const type_length = typestr.at(2) - 48;
s_mask = {p_mask, size / type_length};
}
}

template <typename T>
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
Validate(column);

auto typestr = get<String const>(column.at("typestr"));
CHECK_EQ(typestr.at(1), TypeChar<T>())
<< "Input data type and typestr mismatch. typestr: " << typestr;
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
<< "Input data type and typestr mismatch. typestr: " << typestr;

auto j_shape = get<Array const>(column.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);

if (column.find("strides") != column.cend()) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), 4) << ColumnarErrors::Contigious();
}

auto length = get<Integer const>(j_shape.at(0));

T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
return common::Span<T>{p_data, length};
}

template <typename T>
static Columnar<T> ExtractArray(std::map<std::string, Json> const& column) {
common::Span<T> s_data { ArrayInterfaceHandler::ExtractData<T>(column) };

Columnar<T> foreign_col;
foreign_col.data = s_data;
foreign_col.size = s_data.size();

common::Span<RBitField8::value_type> s_mask;
ArrayInterfaceHandler::ExtractMask(column, &s_mask);

foreign_col.valid = RBitField8(s_mask);

return foreign_col;
}
};

#define DISPATCH_TYPE(__dispatched_func, __typestr, ...) { \
if (__typestr.at(1) == 'f' && __typestr.at(2) == '4') { \
__dispatched_func<float>(__VA_ARGS__); \
} else if (__typestr.at(1) == 'f' && __typestr.at(2) == '8') { \
__dispatched_func<double>(__VA_ARGS__); \
} else if (__typestr.at(1) == 'i' && __typestr.at(2) == '1') { \
__dispatched_func<int8_t>(__VA_ARGS__); \
} else if (__typestr.at(1) == 'i' && __typestr.at(2) == '2') { \
__dispatched_func<int16_t>(__VA_ARGS__); \
} else if (__typestr.at(1) == 'i' && __typestr.at(2) == '4') { \
__dispatched_func<int32_t>(__VA_ARGS__); \
} else if (__typestr.at(1) == 'i' && __typestr.at(2) == '8') { \
__dispatched_func<int64_t>(__VA_ARGS__); \
} else if (__typestr.at(1) == 'u' && __typestr.at(2) == '1') { \
__dispatched_func<uint8_t>(__VA_ARGS__); \
} else if (__typestr.at(1) == 'u' && __typestr.at(2) == '2') { \
__dispatched_func<uint16_t>(__VA_ARGS__); \
} else if (__typestr.at(1) == 'u' && __typestr.at(2) == '4') { \
__dispatched_func<uint32_t>(__VA_ARGS__); \
} else if (__typestr.at(1) == 'u' && __typestr.at(2) == '8') { \
__dispatched_func<uint64_t>(__VA_ARGS__); \
} else { \
LOG(FATAL) << ColumnarErrors::UnknownTypeStr(__typestr); \
} \
}

} // namespace xgboost
#endif  // XGBOOST_DATA_COLUMNAR_H_
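
Two details of the header above are easy to misread: the second and third characters of `typestr` encode the type class and the item size (this is what `TypeChar` and `DISPATCH_TYPE` key on), and the validity mask packs one bit per row, least-significant bit first. A small NumPy sketch of both conventions (assumed behaviour of the array-interface/Arrow formats, not XGBoost code):

    import numpy as np

    def dtype_from_typestr(typestr):
        # e.g. '<f4' -> float32, '<i8' -> int64, '<u1' -> uint8
        kind, size = typestr[1], int(typestr[2])
        return np.dtype({'f': 'float', 'i': 'int', 'u': 'uint'}[kind] + str(size * 8))

    def is_valid(mask_bytes, row):
        # Arrow-style validity bitmask: bit (row % 8) of byte (row // 8) marks the row as valid.
        return bool(mask_bytes[row // 8] >> (row % 8) & 1)

    mask = np.array([0b11111011], dtype=np.uint8)  # row 2 is null
    assert dtype_from_typestr('<f4') == np.float32
    assert not is_valid(mask, 2) and is_valid(mask, 3)
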
@ -2,8 +2,8 @@
 * Copyright 2019 by XGBoost Contributors
 *
 * \file data.cu
 * \brief Handles setting metainfo from array interface.
 */

#include "xgboost/data.h"
#include "xgboost/logging.h"
#include "xgboost/json.h"
@ -12,75 +12,60 @@

namespace xgboost {

template <typename T>
void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<float>* out) {
auto SetDeviceToPtr = [](void* ptr) {
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device;
};

common::Span<T> s_data { ArrayInterfaceHandler::ExtractData<T>(column) };
auto ptr_device = SetDeviceToPtr(s_data.data());
thrust::device_ptr<T> p_src {s_data.data()};

auto length = s_data.size();
out->SetDevice(ptr_device);
out->Resize(length);

auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
thrust::copy(p_src, p_src + length, p_dst);
}

void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
Json j_arr = Json::Load({interface_str.c_str(), interface_str.size()});
auto const& j_arr_obj = get<Object>(j_arr);
std::string key {c_key};
auto version = get<Integer const>(j_arr_obj.at("version"));
CHECK_EQ(version, 1) << ColumnarErrors::Version();
ArrayInterfaceHandler::Validate(j_arr_obj);
if (j_arr_obj.find("mask") != j_arr_obj.cend()) {
LOG(FATAL) << "Meta info " << key << " should be dense, found validity mask";
}
auto const& typestr = get<String const>(j_arr_obj.at("typestr"));

auto typestr = get<String const>(j_arr_obj.at("typestr"));
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();

auto j_shape = get<Array const>(j_arr_obj.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
auto length = get<Integer const>(j_shape.at(0));
CHECK_GT(length, 0) << "Label set cannot be empty.";

if (j_arr_obj.find("strides") != j_arr_obj.cend()) {
auto strides = get<Array const>(j_arr_obj.at("strides"));
CHECK_EQ(get<Integer>(strides.at(0)), 4) << ColumnarErrors::Contigious();
}

float* p_data = GetPtrFromArrayData<float*>(j_arr_obj);

cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, p_data));
int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device));

thrust::device_ptr<float> p_src {p_data};

HostDeviceVector<float>* dst;
if (key == "root_index") {
LOG(FATAL) << "root index for columnar data is not supported.";
} else if (key == "label") {
dst = &labels_;
CHECK_EQ(typestr.at(1), 'f') << "Label"
<< ColumnarErrors::ofType("floating point");
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &labels_);
} else if (key == "weight") {
dst = &weights_;
CHECK_EQ(typestr.at(1), 'f') << "Weight"
<< ColumnarErrors::ofType("floating point");
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &weights_);
} else if (key == "base_margin") {
dst = &base_margin_;
CHECK_EQ(typestr.at(1), 'f') << "Base Margin"
<< ColumnarErrors::ofType("floating point");
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &base_margin_);
} else if (key == "group") {
CHECK_EQ(typestr.at(1), 'u') << "Group"
<< ColumnarErrors::ofType("unsigned 32 bit integers");
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toUInt();
// Ranking is not performed on device.
auto s_data = ArrayInterfaceHandler::ExtractData<uint32_t>(j_arr_obj);
thrust::device_ptr<uint32_t> p_src {s_data.data()};

auto length = s_data.size();
group_ptr_.resize(length + 1);
group_ptr_[0] = 0;
// Ranking is not performed on device.
thrust::copy(p_src, p_src + length, group_ptr_.begin() + 1);
for (size_t i = 1; i < group_ptr_.size(); ++i) {
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
}
std::partial_sum(group_ptr_.begin(), group_ptr_.end(), group_ptr_.begin());

return;
} else {
LOG(FATAL) << "Unknown metainfo: " << key;
}
dst->SetDevice(ptr_device);
dst->Resize(length);
auto p_dst = thrust::device_pointer_cast(dst->DevicePointer());
thrust::copy(p_src, p_src + length, p_dst);
}
} // namespace xgboost
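
The `group` branch above now copies the per-group sizes into `group_ptr_` and turns them into offsets with `std::partial_sum` instead of the hand-written loop. The equivalent NumPy sketch (illustrative only):

    import numpy as np

    group_sizes = np.array([3, 2, 4], dtype=np.uint32)
    # group_ptr[0] = 0, followed by a running sum over the group sizes.
    group_ptr = np.concatenate(([0], np.cumsum(group_sizes)))
    assert group_ptr.tolist() == [0, 3, 5, 9]
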
@ -123,15 +123,29 @@ const SparsePage& SimpleCSRSource::Value() const {

/*!
 * Please be careful that, in official specification, the only three required fields are
 * `shape', `version' and `typestr'. Any other is optional, including `data'. But here
 * we have two additional requirements for input data:
 * we have one additional requirement for input data:
 *
 * - `data' field is required, passing in an empty dataset is not accepted, as most (if
 *   not all) of our algorithms don't have tests for empty datasets. An error is better
 *   than a crash.
 *
 * - `null_count' is required when `mask' is presented. We can compute `null_count'
 *   ourselves and copy the result back to host for memory allocation. But it's in the
 *   specification of Apache Arrow hence it should be readily available,
 * Missing value handling:
 * Missing value is specified:
 *   - Ignore the validity mask from columnar format.
 *   - Remove entries that equal the missing value.
 *   - missing = NaN:
 *     - Remove entries that are NaN.
 *   - missing != NaN:
 *     - Check for NaN entries, throw an error if found.
 * Missing value is not specified:
 *   - Remove entries that are marked invalid by the validity mask.
 *   - Remove NaN entries.
 *
 * What if the invalid value from the dataframe is 0 but missing=NaN is specified in XGBoost? Since
 * the validity mask is ignored, all 0s are preserved in XGBoost.
 *
 * FIXME(trivialfis): Put the above into the documentation after we have a consistent way for
 * processing input data.
 *
 * Sample input:
 * [
@ -160,100 +174,29 @@ const SparsePage& SimpleCSRSource::Value() const {
 *       false
 *     ],
 *     "typestr": "|i1",
 *     "version": 1,
 *     "null_count": 1
 *     "version": 1
 *   }
 * }
 * ]
 */
void SimpleCSRSource::CopyFrom(std::string const& cuda_interfaces_str) {
void SimpleCSRSource::CopyFrom(std::string const& cuda_interfaces_str,
bool has_missing, float missing) {
Json interfaces = Json::Load({cuda_interfaces_str.c_str(),
cuda_interfaces_str.size()});
std::vector<Json> const& columns = get<Array>(interfaces);
size_t n_columns = columns.size();
CHECK_GT(n_columns, 0);
CHECK_GT(n_columns, 0) << "Number of columns must be greater than 0.";

std::vector<Columnar> foreign_cols(n_columns);
for (size_t i = 0; i < columns.size(); ++i) {
CHECK(IsA<Object>(columns[i]));
auto const& column = get<Object const>(columns[i]);

auto version = get<Integer const>(column.at("version"));
CHECK_EQ(version, 1) << ColumnarErrors::Version();

// Find null mask (validity mask) field
// Mask object is also an array interface, but with different requirements.

// TODO(trivialfis): Abstract this into a class that accept a json
// object and turn it into an array (for cupy and numba).
common::Span<RBitField8::value_type> s_mask;
int32_t null_count {0};
if (column.find("mask") != column.cend()) {
auto const& j_mask = get<Object const>(column.at("mask"));
auto p_mask = GetPtrFromArrayData<RBitField8::value_type*>(j_mask);

auto j_shape = get<Array const>(j_mask.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(j_shape.front()) % 8, 0) <<
"Length of validity map must be a multiple of 8 bytes.";
int64_t size = get<Integer>(j_shape.at(0)) *
sizeof(unsigned char) / sizeof(RBitField8::value_type);
s_mask = {p_mask, size};
auto typestr = get<String const>(j_mask.at("typestr"));
auto const& typestr = get<String const>(columns[0]["typestr"]);
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
CHECK_EQ(typestr.at(1), 'i') << "mask" << ColumnarErrors::ofType("unsigned char");
CHECK_EQ(typestr.at(2), '1') << "mask" << ColumnarErrors::toUInt();

CHECK(j_mask.find("null_count") != j_mask.cend()) <<
"Column with null mask must include null_count as "
"part of mask object for XGBoost.";
null_count = get<Integer const>(j_mask.at("null_count"));
}

// Find data field
if (column.find("data") == column.cend()) {
LOG(FATAL) << "Empty dataset passed in.";
}

auto typestr = get<String const>(column.at("typestr"));
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
CHECK_EQ(typestr.at(1), 'f') << "data" << ColumnarErrors::ofType("floating point");
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();

auto j_shape = get<Array const>(column.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);

if (column.find("strides") != column.cend()) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), 4) << ColumnarErrors::Contigious();
}

auto length = get<Integer const>(j_shape.at(0));

float* p_data = GetPtrFromArrayData<float*>(column);
common::Span<float> s_data {p_data, length};

foreign_cols[i].data = s_data;
foreign_cols[i].valid = RBitField8(s_mask);
foreign_cols[i].size = s_data.size();
foreign_cols[i].null_count = null_count;
}

info.num_col_ = n_columns;
info.num_row_ = foreign_cols[0].size;
for (size_t i = 0; i < n_columns; ++i) {
CHECK_EQ(foreign_cols[0].size, foreign_cols[i].size);
info.num_nonzero_ += foreign_cols[i].data.size() - foreign_cols[i].null_count;
}

this->FromDeviceColumnar(foreign_cols);
this->FromDeviceColumnar(columns, has_missing, missing);
}

#if !defined(XGBOOST_USE_CUDA)
void SimpleCSRSource::FromDeviceColumnar(std::vector<Columnar> cols) {
void SimpleCSRSource::FromDeviceColumnar(std::vector<Json> const& columns,
bool has_missing, float missing) {
LOG(FATAL) << "XGBoost version is not compiled with GPU support";
}
#endif  // !defined(XGBOOST_USE_CUDA)
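
A NumPy rendering of the missing-value rules documented in the comment above, for a single column with an optional validity mask (a sketch under the stated rules; the helper name is hypothetical):

    import numpy as np

    def keep_mask(values, valid, has_missing, missing):
        # valid: boolean validity mask or None (all rows valid).
        if not has_missing:
            keep = ~np.isnan(values)
            if valid is not None:
                keep &= valid
            return keep
        if np.isnan(missing):          # missing specified as NaN
            return ~np.isnan(values)
        # missing specified as a concrete value: mask is ignored, NaN is an error.
        assert not np.isnan(values).any(), "missing value is specified but data contains NaN"
        return ~np.isclose(values, missing, atol=1e-6)
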
@ -12,102 +12,190 @@

#include <xgboost/base.h>
#include <xgboost/data.h>

#include <cmath>
#include <vector>
#include <algorithm>

#include "simple_csr_source.h"
#include "columnar.h"
#include "../common/math.h"
#include "../common/bitfield.h"
#include "../common/device_helpers.cuh"

namespace xgboost {
namespace data {

template <size_t kBlockThreads>
__global__ void CountValidKernel(common::Span<Columnar const> columns,
int32_t const n_rows,
common::Span<size_t> offsets) {
// One block for a column
auto const bid = blockIdx.x;
auto const tid = threadIdx.x;
if (bid >= columns.size()) {
template <typename T>
__global__ void CountValidKernel(Columnar<T> const column,
bool has_missing, float missing,
int32_t* flag, common::Span<size_t> offsets) {
auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
bool const missing_is_nan = common::CheckNAN(missing);

if (tid >= column.size) {
return;
}
RBitField8 const mask = columns[bid].valid;
for (auto r = tid; r < n_rows; r += kBlockThreads) {
if (mask.Data() == nullptr || mask.Check(r)) {
atomicAdd(reinterpret_cast<BitFieldAtomicType*>(&offsets[r+1]),
static_cast<BitFieldAtomicType>(1));
RBitField8 const mask = column.valid;

if (!has_missing) {
if ((mask.Data() == nullptr || mask.Check(tid)) &&
!common::CheckNAN(column.data[tid])) {
offsets[tid+1] += 1;
}
} else if (missing_is_nan) {
if (!common::CheckNAN(column.data[tid])) {
offsets[tid+1] += 1;
}
} else {
if (!common::CloseTo(column.data[tid], missing)) {
offsets[tid+1] += 1;
}
if (common::CheckNAN(column.data[tid])) {
*flag = 1;
}
}
}

__global__ void CreateCSRKernel(Columnar const column,
int32_t colid,
common::Span<size_t> offsets,
common::Span<Entry> out_data) {
auto tid = threadIdx.x + blockDim.x * blockIdx.x;
template <typename T>
__device__ void AssignValue(T fvalue, int32_t colid,
common::Span<size_t> out_offsets, common::Span<Entry> out_data) {
auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
int32_t oid = out_offsets[tid];
out_data[oid].fvalue = fvalue;
out_data[oid].index = colid;
out_offsets[tid] += 1;
}

template <typename T>
__global__ void CreateCSRKernel(Columnar<T> const column,
int32_t colid, bool has_missing, float missing,
common::Span<size_t> offsets, common::Span<Entry> out_data) {
auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
if (column.size <= tid) {
return;
}

if (column.valid.Data() == nullptr || column.valid.Check(tid)) {
int32_t oid = offsets[tid];
out_data[oid].fvalue = column.data[tid];
out_data[oid].index = colid;
offsets[tid] += 1;
bool const missing_is_nan = common::CheckNAN(missing);
if (!has_missing) {
// no missing value is specified
if ((column.valid.Data() == nullptr || column.valid.Check(tid)) &&
!common::CheckNAN(column.data[tid])) {
AssignValue(column.data[tid], colid, offsets, out_data);
}
} else if (missing_is_nan) {
// specified missing value, but it's NaN
if (!common::CheckNAN(column.data[tid])) {
AssignValue(column.data[tid], colid, offsets, out_data);
}
} else {
// specified missing value, and it's not NaN
if (!common::CloseTo(column.data[tid], missing)) {
AssignValue(column.data[tid], colid, offsets, out_data);
}
}
}

void SimpleCSRSource::FromDeviceColumnar(std::vector<Columnar> cols) {
uint64_t const n_cols = cols.size();
uint64_t const n_rows = cols[0].size;
template <typename T>
void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
bool has_missing, float missing,
HostDeviceVector<size_t>* out_offset,
dh::caching_device_vector<int32_t>* out_d_flag,
uint32_t* out_n_rows) {
int32_t constexpr kThreads = 256;
auto const& j_column = j_columns[column_id];
auto const& column_obj = get<Object const>(j_column);
Columnar<T> foreign_column = ArrayInterfaceHandler::ExtractArray<T>(column_obj);
uint32_t const n_rows = foreign_column.size;

auto ptr = cols[0].data.data();
auto ptr = foreign_column.data.data();
int32_t device = dh::CudaGetPointerDevice(ptr);
CHECK_NE(device, -1);

for (int32_t i = 1; i < n_cols; ++i) {
auto ptr = cols[i].data.data();
int32_t ptr_device = dh::CudaGetPointerDevice(ptr);
CHECK_EQ(device, ptr_device)
<< "GPU ID at 0^th column: " << device << ", "
<< "GPU ID at column " << i << ": " << ptr_device;
}

dh::safe_cuda(cudaSetDevice(device));

page_.offset.SetDevice(device);
page_.offset.Resize(info.num_row_ + 1);
if (column_id == 0) {
out_offset->SetDevice(device);
out_offset->Resize(n_rows + 1);
}
CHECK_EQ(out_offset->DeviceIdx(), device)
<< "All columns should use the same device.";
CHECK_EQ(out_offset->Size(), n_rows + 1)
<< "All columns should have same number of rows.";

page_.data.SetDevice(device);
page_.data.Resize(info.num_nonzero_);
common::Span<size_t> s_offsets = out_offset->DeviceSpan();

auto s_data = page_.data.DeviceSpan();
auto s_offsets = page_.offset.DeviceSpan();
CHECK_EQ(s_offsets.size(), n_rows + 1);
int32_t const kBlocks = common::DivRoundUp(n_rows, kThreads);
CountValidKernel<T><<<kBlocks, kThreads>>>(
foreign_column,
has_missing, missing,
out_d_flag->data().get(), s_offsets);
*out_n_rows = n_rows;
}

template <typename T>
void CreateCSR(std::vector<Json> const& j_columns, uint32_t column_id, uint32_t n_rows,
bool has_missing, float missing,
dh::device_vector<size_t>* tmp_offset, common::Span<Entry> s_data) {
int32_t constexpr kThreads = 256;
dh::device_vector<Columnar> d_cols(cols);
auto s_d_cols = dh::ToSpan(d_cols);
auto const& j_column = j_columns[column_id];
auto const& column_obj = get<Object const>(j_column);
Columnar<T> foreign_column = ArrayInterfaceHandler::ExtractArray<T>(column_obj);
int32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
CreateCSRKernel<T><<<kBlocks, kThreads>>>(foreign_column, column_id, has_missing, missing,
dh::ToSpan(*tmp_offset), s_data);
}

dh::safe_cuda(cudaMemset(s_offsets.data(), 0, sizeof(int32_t) * (n_rows + 1)));
void SimpleCSRSource::FromDeviceColumnar(std::vector<Json> const& columns,
bool has_missing, float missing) {
auto const n_cols = columns.size();
int32_t constexpr kThreads = 256;

CountValidKernel<kThreads><<<n_cols, kThreads>>>(s_d_cols, n_rows, s_offsets);
dh::caching_device_vector<int32_t> d_flag;
if (!common::CheckNAN(missing)) {
d_flag.resize(1);
thrust::fill(d_flag.begin(), d_flag.end(), 0);
}
uint32_t n_rows {0};
for (size_t i = 0; i < n_cols; ++i) {
auto const& typestr = get<String const>(columns[i]["typestr"]);
DISPATCH_TYPE(CountValid, typestr,
columns, i, has_missing, missing, &(this->page_.offset), &d_flag, &n_rows);
}
// don't pay for what you don't use.
if (!common::CheckNAN(missing)) {
int32_t flag {0};
dh::safe_cuda(cudaMemcpy(&flag, d_flag.data().get(), sizeof(int32_t), cudaMemcpyDeviceToHost));
CHECK_EQ(flag, 0) << "missing value is specified but input data contains NaN.";
}

info.num_col_ = n_cols;
info.num_row_ = n_rows;

auto s_offsets = this->page_.offset.DeviceSpan();
thrust::device_ptr<size_t> p_offsets(s_offsets.data());
CHECK_GE(s_offsets.size(), n_rows + 1);

thrust::inclusive_scan(p_offsets, p_offsets + n_rows + 1, p_offsets);
// Created for building csr matrix, where we need to change index
// after processing each column.
dh::device_vector<size_t> tmp_offset(page_.offset.Size());
thrust::copy(p_offsets, p_offsets + n_rows + 1, tmp_offset.begin());
// Created for building csr matrix, where we need to change index after processing each
// column.
dh::device_vector<size_t> tmp_offset(this->page_.offset.Size());
dh::safe_cuda(cudaMemcpy(tmp_offset.data().get(), s_offsets.data(),
s_offsets.size_bytes(), cudaMemcpyDeviceToDevice));

// We can use null_count from columnar data format, but that will add a non-standard
// entry in the array interface, also involves accumulating from all columns. Invoking
// one copy seems easier.
this->info.num_nonzero_ = tmp_offset.back();

int device = this->page_.offset.DeviceIdx();
this->page_.data.SetDevice(device);
this->page_.data.Resize(this->info.num_nonzero_);
auto s_data = this->page_.data.DeviceSpan();

int32_t kBlocks = common::DivRoundUp(n_rows, kThreads);

for (size_t col = 0; col < n_cols; ++col) {
CreateCSRKernel<<<kBlocks, kThreads>>>(d_cols[col], col, dh::ToSpan(tmp_offset), s_data);
for (size_t i = 0; i < n_cols; ++i) {
auto const& typestr = get<String const>(columns[i]["typestr"]);
DISPATCH_TYPE(CreateCSR, typestr, columns, i, n_rows,
has_missing, missing, &tmp_offset, s_data);
}
}
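
The host-side flow in `FromDeviceColumnar` above is: count the valid entries per row across all columns, run an inclusive scan to get CSR row offsets, then let `CreateCSRKernel` write each kept entry at `offsets[row]` and bump that offset. A compact NumPy sketch of the same idea (illustrative, single-threaded):

    import numpy as np

    def columnar_to_csr(columns, keep):
        # columns: list of 1-D arrays; keep: list of boolean masks of kept entries.
        n_rows = len(columns[0])
        counts = np.zeros(n_rows + 1, dtype=np.int64)
        for k in keep:
            counts[1:] += k                      # CountValidKernel
        offsets = np.cumsum(counts)              # inclusive scan
        data = np.empty(offsets[-1], dtype=[('value', 'f4'), ('index', 'i4')])
        cursor = offsets[:-1].copy()             # tmp_offset
        for col_id, (col, k) in enumerate(zip(columns, keep)):
            for row in np.flatnonzero(k):        # CreateCSRKernel
                data[cursor[row]] = (col[row], col_id)
                cursor[row] += 1
        return offsets, data
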
@ -14,6 +14,7 @@
#include <algorithm>
#include <string>
#include <vector>
#include <limits>

#include "columnar.h"

@ -52,8 +53,11 @@ class SimpleCSRSource : public DataSource<SparsePage> {
/*!
 * \brief copy content of data from foreign **GPU** columnar buffer.
 * \param interfaces_str JSON representation of cuda array interfaces.
 * \param has_missing Whether the user supplied a missing value.
 * \param missing The missing value set by the user.
 */
void CopyFrom(std::string const& cuda_interfaces_str);
void CopyFrom(std::string const& cuda_interfaces_str, bool has_missing,
bst_float missing = std::numeric_limits<float>::quiet_NaN());
/*!
 * \brief Load data from binary stream.
 * \param fi the pointer to load data from.
@ -76,9 +80,12 @@ class SimpleCSRSource : public DataSource<SparsePage> {
private:
/*!
 * \brief copy content of data from foreign GPU columnar buffer.
 * \param cols foreign columns data buffer.
 * \param columns JSON representation of array interfaces.
 * \param missing Specified missing value.
 */
void FromDeviceColumnar(std::vector<Columnar> cols);
void FromDeviceColumnar(std::vector<Json> const& columns,
bool has_missing = false,
float missing = std::numeric_limits<float>::quiet_NaN());
/*! \brief internal variable, used to support iterator interface */
bool at_first_{true};
};
@ -7,11 +7,13 @@
#include "../../../src/common/device_helpers.cuh"

namespace xgboost {
TEST(MetaInfo, FromInterface) {
cudaSetDevice(0);
constexpr size_t kRows = 16;

thrust::device_vector<float> d_data(kRows);
template <typename T>
std::string PrepareData(std::string typestr, thrust::device_vector<T>* out) {
constexpr size_t kRows = 16;
out->resize(kRows);
auto& d_data = *out;

for (size_t i = 0; i < d_data.size(); ++i) {
d_data[i] = i * 2.0;
}
@ -22,7 +24,7 @@ TEST(MetaInfo, FromInterface) {
column["shape"] = Array(j_shape);
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(4)))});
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String("<f4");
column["typestr"] = String(typestr);

auto p_d_data = dh::Raw(d_data);
std::vector<Json> j_data {
@ -34,6 +36,15 @@ TEST(MetaInfo, FromInterface) {
Json::Dump(column, &ss);
std::string str = ss.str();

return str;
}

TEST(MetaInfo, FromInterface) {
cudaSetDevice(0);
thrust::device_vector<float> d_data;

std::string str = PrepareData<float>("<f4", &d_data);

MetaInfo info;
info.SetInfo("label", str.c_str());

@ -53,5 +64,22 @@ TEST(MetaInfo, FromInterface) {
for (size_t i = 0; i < d_data.size(); ++i) {
ASSERT_EQ(h_base_margin[i], d_data[i]);
}

EXPECT_ANY_THROW({info.SetInfo("group", str.c_str());});
}

TEST(MetaInfo, Group) {
cudaSetDevice(0);
thrust::device_vector<uint32_t> d_data;
std::string str = PrepareData<uint32_t>("<u4", &d_data);

MetaInfo info;

info.SetInfo("group", str.c_str());
auto const& h_group = info.group_ptr_;
ASSERT_EQ(h_group.size(), d_data.size() + 1);
for (size_t i = 1; i < h_group.size(); ++i) {
ASSERT_EQ(h_group[i], d_data[i-1] + h_group[i-1]) << "i: " << i;
}
}
} // namespace xgboost
@ -8,17 +8,48 @@
#include "../../../src/common/bitfield.h"
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/data/simple_csr_source.h"
#include "../../../src/data/columnar.h"

namespace xgboost {

TEST(SimpleCSRSource, FromColumnarDense) {
constexpr size_t kRows = 16;
TEST(ArrayInterfaceHandler, Error) {
constexpr size_t kRows {16};
Json column { Object() };
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
Json(Boolean(false))};

auto const& column_obj = get<Object>(column);
// missing version
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
column["version"] = Integer(static_cast<Integer::Int>(1));
// missing data
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
column["data"] = j_data;
// missing typestr
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
column["typestr"] = String("<f4");
// nullptr is not valid
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
thrust::device_vector<float> d_data(kRows);
j_data = {Json(Integer(reinterpret_cast<Integer::Int>(d_data.data().get()))),
Json(Boolean(false))};
column["data"] = j_data;
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj));
}

template <typename T>
Json GenerateDenseColumn(std::string const& typestr, size_t kRows,
thrust::device_vector<T>* out_d_data) {
auto& d_data = *out_d_data;
Json column { Object() };
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(4)))});

thrust::device_vector<float> d_data(kRows);
d_data.resize(kRows);
for (size_t i = 0; i < d_data.size(); ++i) {
d_data[i] = i * 2.0;
}
@ -31,34 +62,86 @@ TEST(SimpleCSRSource, FromColumnarDense) {
column["data"] = j_data;

column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String("<f4");
Json column_arr {Array{std::vector<Json>{column}}};
column["typestr"] = String(typestr);
return column;
}

TEST(SimpleCSRSource, FromColumnarDense) {
constexpr size_t kRows {16};
constexpr size_t kCols {2};
std::vector<Json> columns;
thrust::device_vector<float> d_data_0(kRows);
thrust::device_vector<int32_t> d_data_1(kRows);
columns.emplace_back(GenerateDenseColumn<float>("<f4", kRows, &d_data_0));
columns.emplace_back(GenerateDenseColumn<int32_t>("<i4", kRows, &d_data_1));

Json column_arr {columns};

std::stringstream ss;
Json::Dump(column_arr, &ss);
std::string str = ss.str();

// no missing value
{
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str());
source->CopyFrom(str.c_str(), false);

auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
for (size_t i = 0; i < kRows; ++i) {
auto e = data[i];
ASSERT_NEAR(e.fvalue, i * 2.0, kRtEps);
ASSERT_EQ(e.index, 0);  // feature 0
for (size_t i = 0; i < kRows; i++) {
auto const idx = i * kCols;
auto const e_0 = data.at(idx);
ASSERT_NEAR(e_0.fvalue, i * 2.0, kRtEps) << "idx: " << idx;
ASSERT_EQ(e_0.index, 0);  // feature 0

auto e_1 = data.at(idx+1);
ASSERT_NEAR(e_1.fvalue, i * 2.0, kRtEps);
ASSERT_EQ(e_1.index, 1);  // feature 1
}
ASSERT_EQ(offset.back(), 16);
ASSERT_EQ(offset.back(), kRows * kCols);
for (size_t i = 0; i < kRows + 1; ++i) {
ASSERT_EQ(offset[i], i);
ASSERT_EQ(offset[i], i * kCols);
}
ASSERT_EQ(source->info.num_row_, kRows);
ASSERT_EQ(source->info.num_col_, kCols);
}

// with missing value specified
{
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str(), true, 4.0);

auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
ASSERT_EQ(data.size(), kRows * kCols - 2);
ASSERT_NEAR(data[4].fvalue, 6.0, kRtEps);  // kCols * 2
ASSERT_EQ(offset.back(), 30);
for (size_t i = 3; i < kRows + 1; ++i) {
ASSERT_EQ(offset[i], (i - 1) * 2);
}
ASSERT_EQ(source->info.num_row_, kRows);
ASSERT_EQ(source->info.num_col_, kCols);
}

{
// no missing value, but has NaN
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
d_data_0[3] = std::numeric_limits<float>::quiet_NaN();
ASSERT_TRUE(std::isnan(d_data_0[3]));  // removes 6.0
source->CopyFrom(str.c_str(), false);

auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
ASSERT_EQ(data.size(), kRows * kCols - 1);
ASSERT_NEAR(data[7].fvalue, 8.0, kRtEps);
ASSERT_EQ(source->info.num_row_, kRows);
ASSERT_EQ(source->info.num_col_, kCols);
}
}

TEST(SimpleCSRSource, FromColumnarWithEmptyRows) {
// In this test we construct a data storage similar to cudf
constexpr size_t kRows = 102;
constexpr size_t kCols = 24;
constexpr size_t kMissingRows = 3;

std::vector<Json> v_columns (kCols);
std::vector<dh::device_vector<float>> columns_data(kCols);
@ -90,6 +173,7 @@ TEST(SimpleCSRSource, FromColumnarWithEmptyRows) {
// Construct the mask object.
col["mask"] = Object();
auto& j_mask = col["mask"];
j_mask["version"] = Integer(static_cast<Integer::Int>(1));
auto& mask_storage = column_bitfields[i];
mask_storage.resize(16);  // 16 bytes

@ -111,7 +195,6 @@ TEST(SimpleCSRSource, FromColumnarWithEmptyRows) {
Json(Boolean(false))};
j_mask["shape"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(16)))});
j_mask["typestr"] = String("|i1");
j_mask["null_count"] = Json(Integer(static_cast<Integer::Int>(kMissingRows)));
}

Json column_arr {Array(v_columns)};
@ -119,7 +202,7 @@ TEST(SimpleCSRSource, FromColumnarWithEmptyRows) {
Json::Dump(column_arr, &ss);
std::string str = ss.str();
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str());
source->CopyFrom(str.c_str(), false);

auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
@ -131,6 +214,7 @@ TEST(SimpleCSRSource, FromColumnarWithEmptyRows) {
ASSERT_NEAR(data[j].fvalue, i - 1, kRtEps);
}
}
ASSERT_EQ(source->info.num_row_, kRows);
}

TEST(SimpleCSRSource, FromColumnarSparse) {
@ -149,6 +233,8 @@ TEST(SimpleCSRSource, FromColumnarSparse) {
for (size_t j = 0; j < mask.size(); ++j) {
mask[j] = ~0;
}
// the 2^th entry of first column is invalid
// [0 0 0 0 0 1 0 0]
mask[0] = ~(kUCOne << 2);
}
{
@ -159,6 +245,8 @@ TEST(SimpleCSRSource, FromColumnarSparse) {
for (size_t j = 0; j < mask.size(); ++j) {
mask[j] = ~0;
}
// the 19^th entry of second column is invalid
// [~0~], [~0~], [0 0 0 0 1 0 0 0]
mask[2] = ~(kUCOne << 3);
}

@ -186,12 +274,12 @@ TEST(SimpleCSRSource, FromColumnarSparse) {

column["mask"] = Object();
auto& j_mask = column["mask"];
j_mask["version"] = Integer(static_cast<Integer::Int>(1));
j_mask["data"] = std::vector<Json>{
Json(Integer(reinterpret_cast<Integer::Int>(column_bitfields[c].data().get()))),
Json(Boolean(false))};
j_mask["shape"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(8)))});
j_mask["typestr"] = String("|i1");
j_mask["null_count"] = Json(Integer(static_cast<Integer::Int>(1)));
}

Json column_arr {Array(j_columns)};
@ -200,8 +288,9 @@ TEST(SimpleCSRSource, FromColumnarSparse) {
Json::Dump(column_arr, &ss);
std::string str = ss.str();

{
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str());
source->CopyFrom(str.c_str(), false);

auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
@ -213,4 +302,50 @@ TEST(SimpleCSRSource, FromColumnarSparse) {
ASSERT_EQ(data[37].fvalue, 19);
}

{
// with missing value
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str(), true, /*missing=*/2.0);

auto const& data = source->page_.data.HostVector();
ASSERT_NE(data[4].fvalue, 2.0);
}

{
// no missing value, but has NaN
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
columns_data[0][4] = std::numeric_limits<float>::quiet_NaN();  // 0^th column 4^th row
ASSERT_TRUE(std::isnan(columns_data[0][4]));
source->CopyFrom(str.c_str(), false);

auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
// Two invalid entries and one NaN, in CSC
// 0^th column: 0, 1, 4, 5, 6, ..., kRows
// 1^th column: 0, 1, 2, 3, ..., 19, 21, ..., kRows
// Turning it into CSR:
// | 0, 0 | 1, 1 | 2 | 3, 3 | 4 | ...
ASSERT_EQ(data.size(), kRows * kCols - 3);
ASSERT_EQ(data[4].index, 1);  // from 1^th column
ASSERT_EQ(data[5].fvalue, 3.0);
ASSERT_EQ(data[7].index, 1);  // from 1^th column
ASSERT_EQ(data[7].fvalue, 4.0);

ASSERT_EQ(data[offset[2]].fvalue, 2.0);
ASSERT_EQ(data[offset[4]].fvalue, 4.0);
}

{
// with NaN as missing value
// NaN is already set up by above test
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str(), true,
/*missing=*/std::numeric_limits<float>::quiet_NaN());

auto const& data = source->page_.data.HostVector();
ASSERT_EQ(data.size(), kRows * kCols - 1);
ASSERT_EQ(data[8].fvalue, 4.0);
}
}

} // namespace xgboost
@ -1,5 +1,6 @@
#!/usr/bin/python
import xgboost as xgb
import numpy as np

xgb.rabit.init()

@ -59,6 +60,7 @@ X = [
4415.50,22731.62,1.00,55.00,0.00,499.94,22.00,0.58,67.00,0.21,341.72,16.00,0.00,965.07,
17.00,138.41,0.00,0.00,1.00,0.14,1.00,0.02,0.35,1.69,369.00,1300.00,25.00,0.00,0.01,
0.00,0.00,0.00,0.00,52.00,8.00]]
X = np.array(X)
y = [1, 0]

dtrain = xgb.DMatrix(X, label=y)
@ -6,6 +6,35 @@ sys.path.append("tests/python")
import testing as tm


def dmatrix_from_cudf(input_type, missing=np.NAN):
'''Test constructing DMatrix from cudf'''
import cudf
import pandas as pd

kRows = 80
kCols = 3

na = np.random.randn(kRows, kCols)
na[:, 0:2] = na[:, 0:2].astype(input_type)

na[5, 0] = missing
na[3, 1] = missing

pa = pd.DataFrame({'0': na[:, 0],
'1': na[:, 1],
'2': na[:, 2].astype(np.int32)})

np_label = np.random.randn(kRows).astype(input_type)
pa_label = pd.DataFrame(np_label)

cd: cudf.DataFrame = cudf.from_pandas(pa)
cd_label: cudf.DataFrame = cudf.from_pandas(pa_label)

dtrain = xgb.DMatrix(cd, missing=missing, label=cd_label)
assert dtrain.num_col() == kCols
assert dtrain.num_row() == kRows


class TestFromColumnar:
'''Tests for constructing DMatrix from a data structure conforming to the Apache
Arrow specification.'''
@ -13,30 +42,13 @@ Arrow specification.'''
@pytest.mark.skipif(**tm.no_cudf())
def test_from_cudf(self):
'''Test constructing DMatrix from cudf'''
import cudf
import pandas as pd
dmatrix_from_cudf(np.float32, np.NAN)
dmatrix_from_cudf(np.float64, np.NAN)

kRows = 80
kCols = 2
dmatrix_from_cudf(np.uint8, 2)
dmatrix_from_cudf(np.uint32, 3)
dmatrix_from_cudf(np.uint64, 4)

na = np.random.randn(kRows, kCols).astype(np.float32)
na[3, 1] = np.NAN
na[5, 0] = np.NAN

pa = pd.DataFrame(na)

np_label = np.random.randn(kRows).astype(np.float32)
pa_label = pd.DataFrame(np_label)

names = []

for i in range(0, kCols):
names.append(str(i))
pa.columns = names

cd: cudf.DataFrame = cudf.from_pandas(pa)
cd_label: cudf.DataFrame = cudf.from_pandas(pa_label)

dtrain = xgb.DMatrix(cd, label=cd_label)
assert dtrain.num_col() == kCols
assert dtrain.num_row() == kRows
dmatrix_from_cudf(np.int8, 2)
dmatrix_from_cudf(np.int32, -2)
dmatrix_from_cudf(np.int64, -3)
@ -67,17 +67,17 @@ class TestBasic(unittest.TestCase):
def test_np_view(self):
# Sliced Float32 array
y = np.array([12, 34, 56], np.float32)[::2]
from_view = xgb.DMatrix([], label=y).get_label()
from_array = xgb.DMatrix([], label=y + 0).get_label()
from_view = xgb.DMatrix(np.array([[]]), label=y).get_label()
from_array = xgb.DMatrix(np.array([[]]), label=y + 0).get_label()
assert (from_view.shape == from_array.shape)
assert (from_view == from_array).all()

# Sliced UInt array
z = np.array([12, 34, 56], np.uint32)[::2]
dmat = xgb.DMatrix([])
dmat = xgb.DMatrix(np.array([[]]))
dmat.set_uint_info('root_index', z)
from_view = dmat.get_uint_info('root_index')
dmat = xgb.DMatrix([])
dmat = xgb.DMatrix(np.array([[]]))
dmat.set_uint_info('root_index', z + 0)
from_array = dmat.get_uint_info('root_index')
assert (from_view.shape == from_array.shape)
@ -256,7 +256,7 @@ class TestBasic(unittest.TestCase):
assert dm.num_row() == 5
assert dm.num_col() == 5

data = np.matrix([[1, 2], [3, 4]])
data = np.array([[1, 2], [3, 4]])
dm = xgb.DMatrix(data)
assert dm.num_row() == 2
assert dm.num_col() == 2
@ -430,4 +430,3 @@ class TestBasicPathLike(unittest.TestCase):

# invalid values raise Type error
self.assertRaises(TypeError, xgb.compat.os_fspath, 123)


@ -69,8 +69,8 @@ class TestUpdaters(unittest.TestCase):
nan = np.nan
param = {'missing': nan, 'tree_method': 'hist'}
model = xgb.XGBRegressor(**param)
X = [[6.18827160e+05, 1.73000000e+02], [6.37345679e+05, nan],
[6.38888889e+05, nan], [6.28086420e+05, nan]]
X = np.array([[6.18827160e+05, 1.73000000e+02], [6.37345679e+05, nan],
[6.38888889e+05, nan], [6.28086420e+05, nan]])
y = [1000000., 0., 0., 500000.]
w = [0, 0, 1, 0]
model.fit(X, y, sample_weight=w)

@ -19,7 +19,7 @@ pytestmark = pytest.mark.skipif(**tm.no_dask())

def run_train():
# Contains one label equal to rank
dmat = xgb.DMatrix([[0]], label=[xgb.rabit.get_rank()])
dmat = xgb.DMatrix(np.array([[0]]), label=[xgb.rabit.get_rank()])
bst = xgb.train({"eta": 1.0, "lambda": 0.0}, dmat, 1)
pred = bst.predict(dmat)
expected_result = np.average(range(xgb.rabit.get_world_size()))
@ -78,7 +78,7 @@ def test_get_local_data(client):

def run_sklearn():
# Contains one label equal to rank
X = [[0]]
X = np.array([[0]])
y = [xgb.rabit.get_rank()]
model = xgb.XGBRegressor(learning_rate=1.0)
model.fit(X, y)

@ -393,7 +393,8 @@ def test_sklearn_nfolds_cv():
nfolds = 5
skf = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=seed)

cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, seed=seed, as_pandas=True)
cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds,
seed=seed, as_pandas=True)
cv2 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds,
folds=skf, seed=seed, as_pandas=True)
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds,