Use dynamic types for array interface columns instead of templates (#5108)
This commit is contained in:
parent
b915788708
commit
3d04a8cc97
@ -14,21 +14,9 @@
|
|||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/span.h"
|
#include "xgboost/span.h"
|
||||||
|
|
||||||
#include "../common/bitfield.h"
|
#include "../common/bitfield.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
// A view over __array_interface__
|
|
||||||
template <typename T>
|
|
||||||
struct Columnar {
|
|
||||||
using mask_type = unsigned char;
|
|
||||||
using index_type = int32_t;
|
|
||||||
|
|
||||||
common::Span<T> data;
|
|
||||||
RBitField8 valid;
|
|
||||||
int32_t size;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Common errors in parsing columnar format.
|
// Common errors in parsing columnar format.
|
||||||
struct ColumnarErrors {
|
struct ColumnarErrors {
|
||||||
static char const* Contigious() {
|
static char const* Contigious() {
|
||||||
@ -97,8 +85,8 @@ struct ColumnarErrors {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string UnSupportedType(std::string const& typestr) {
|
static std::string UnSupportedType(const char (&typestr)[3]) {
|
||||||
return TypeStr(typestr.at(1)) + " is not supported.";
|
return TypeStr(typestr[1]) + " is not supported.";
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -200,6 +188,19 @@ class ArrayInterfaceHandler {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static size_t ExtractLength(std::map<std::string, Json> const& column) {
|
||||||
|
auto j_shape = get<Array const>(column.at("shape"));
|
||||||
|
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
|
||||||
|
auto typestr = get<String const>(column.at("typestr"));
|
||||||
|
if (column.find("strides") != column.cend()) {
|
||||||
|
auto strides = get<Array const>(column.at("strides"));
|
||||||
|
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
|
||||||
|
CHECK_EQ(get<Integer>(strides.at(0)), typestr.at(2) - '0')
|
||||||
|
<< ColumnarErrors::Contigious();
|
||||||
|
}
|
||||||
|
|
||||||
|
return static_cast<size_t>(get<Integer const>(j_shape.at(0)));
|
||||||
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
|
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
|
||||||
Validate(column);
|
Validate(column);
|
||||||
@ -210,70 +211,102 @@ class ArrayInterfaceHandler {
|
|||||||
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
|
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
|
||||||
<< "Input data type and typestr mismatch. typestr: " << typestr;
|
<< "Input data type and typestr mismatch. typestr: " << typestr;
|
||||||
|
|
||||||
auto j_shape = get<Array const>(column.at("shape"));
|
|
||||||
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
|
|
||||||
|
|
||||||
if (column.find("strides") != column.cend()) {
|
auto length = ExtractLength(column);
|
||||||
auto strides = get<Array const>(column.at("strides"));
|
|
||||||
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
|
|
||||||
CHECK_EQ(get<Integer>(strides.at(0)), sizeof(T)) << ColumnarErrors::Contigious();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto length = static_cast<size_t>(get<Integer const>(j_shape.at(0)));
|
|
||||||
|
|
||||||
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
||||||
return common::Span<T>{p_data, length};
|
return common::Span<T>{p_data, length};
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
// A view over __array_interface__
|
||||||
static Columnar<T> ExtractArray(std::map<std::string, Json> const& column) {
|
class Columnar {
|
||||||
common::Span<T> s_data { ArrayInterfaceHandler::ExtractData<T>(column) };
|
using mask_type = unsigned char;
|
||||||
|
using index_type = int32_t;
|
||||||
|
|
||||||
Columnar<T> foreign_col;
|
public:
|
||||||
foreign_col.data = s_data;
|
explicit Columnar(std::map<std::string, Json> const& column) {
|
||||||
foreign_col.size = s_data.size();
|
ArrayInterfaceHandler::Validate(column);
|
||||||
|
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
||||||
|
CHECK(data) << "Column is null";
|
||||||
|
size = ArrayInterfaceHandler::ExtractLength(column);
|
||||||
|
|
||||||
common::Span<RBitField8::value_type> s_mask;
|
common::Span<RBitField8::value_type> s_mask;
|
||||||
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
|
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
|
||||||
|
|
||||||
foreign_col.valid = RBitField8(s_mask);
|
valid = RBitField8(s_mask);
|
||||||
|
|
||||||
if (s_mask.data()) {
|
if (s_mask.data()) {
|
||||||
CHECK_EQ(n_bits, foreign_col.data.size())
|
CHECK_EQ(n_bits, size)
|
||||||
<< "Shape of bit mask doesn't match data shape. "
|
<< "Shape of bit mask doesn't match data shape. "
|
||||||
<< "XGBoost doesn't support internal broadcasting.";
|
<< "XGBoost doesn't support internal broadcasting.";
|
||||||
}
|
}
|
||||||
|
auto typestr = get<String const>(column.at("typestr"));
|
||||||
return foreign_col;
|
type[0] = typestr.at(0);
|
||||||
|
type[1] = typestr.at(1);
|
||||||
|
type[2] = typestr.at(2);
|
||||||
|
this->CheckType();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CheckType() const {
|
||||||
|
if (type[1] == 'f' && type[2] == '4') {
|
||||||
|
return;
|
||||||
|
} else if (type[1] == 'f' && type[2] == '8') {
|
||||||
|
return;
|
||||||
|
} else if (type[1] == 'i' && type[2] == '1') {
|
||||||
|
return;
|
||||||
|
} else if (type[1] == 'i' && type[2] == '2') {
|
||||||
|
return;
|
||||||
|
} else if (type[1] == 'i' && type[2] == '4') {
|
||||||
|
return;
|
||||||
|
} else if (type[1] == 'i' && type[2] == '8') {
|
||||||
|
return;
|
||||||
|
} else if (type[1] == 'u' && type[2] == '1') {
|
||||||
|
return;
|
||||||
|
} else if (type[1] == 'u' && type[2] == '2') {
|
||||||
|
return;
|
||||||
|
} else if (type[1] == 'u' && type[2] == '4') {
|
||||||
|
return;
|
||||||
|
} else if (type[1] == 'u' && type[2] == '8') {
|
||||||
|
return;
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << ColumnarErrors::UnSupportedType(type);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE float GetElement(size_t idx) const {
|
||||||
|
if (type[1] == 'f' && type[2] == '4') {
|
||||||
|
return reinterpret_cast<float*>(data)[idx];
|
||||||
|
} else if (type[1] == 'f' && type[2] == '8') {
|
||||||
|
return reinterpret_cast<double*>(data)[idx];
|
||||||
|
} else if (type[1] == 'i' && type[2] == '1') {
|
||||||
|
return reinterpret_cast<int8_t*>(data)[idx];
|
||||||
|
} else if (type[1] == 'i' && type[2] == '2') {
|
||||||
|
return reinterpret_cast<int16_t*>(data)[idx];
|
||||||
|
} else if (type[1] == 'i' && type[2] == '4') {
|
||||||
|
return reinterpret_cast<int32_t*>(data)[idx];
|
||||||
|
} else if (type[1] == 'i' && type[2] == '8') {
|
||||||
|
return reinterpret_cast<int64_t*>(data)[idx];
|
||||||
|
} else if (type[1] == 'u' && type[2] == '1') {
|
||||||
|
return reinterpret_cast<uint8_t*>(data)[idx];
|
||||||
|
} else if (type[1] == 'u' && type[2] == '2') {
|
||||||
|
return reinterpret_cast<uint16_t*>(data)[idx];
|
||||||
|
} else if (type[1] == 'u' && type[2] == '4') {
|
||||||
|
return reinterpret_cast<uint32_t*>(data)[idx];
|
||||||
|
} else if (type[1] == 'u' && type[2] == '8') {
|
||||||
|
return reinterpret_cast<uint64_t*>(data)[idx];
|
||||||
|
} else {
|
||||||
|
SPAN_CHECK(false);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RBitField8 valid;
|
||||||
|
int32_t size;
|
||||||
|
void* data;
|
||||||
|
char type[3];
|
||||||
};
|
};
|
||||||
|
|
||||||
#define DISPATCH_TYPE(__dispatched_func, __typestr, ...) { \
|
|
||||||
CHECK_EQ(__typestr.size(), 3) << ColumnarErrors::TypestrFormat(); \
|
|
||||||
if (__typestr.at(1) == 'f' && __typestr.at(2) == '4') { \
|
|
||||||
__dispatched_func<float>(__VA_ARGS__); \
|
|
||||||
} else if (__typestr.at(1) == 'f' && __typestr.at(2) == '8') { \
|
|
||||||
__dispatched_func<double>(__VA_ARGS__); \
|
|
||||||
} else if (__typestr.at(1) == 'i' && __typestr.at(2) == '1') { \
|
|
||||||
__dispatched_func<int8_t>(__VA_ARGS__); \
|
|
||||||
} else if (__typestr.at(1) == 'i' && __typestr.at(2) == '2') { \
|
|
||||||
__dispatched_func<int16_t>(__VA_ARGS__); \
|
|
||||||
} else if (__typestr.at(1) == 'i' && __typestr.at(2) == '4') { \
|
|
||||||
__dispatched_func<int32_t>(__VA_ARGS__); \
|
|
||||||
} else if (__typestr.at(1) == 'i' && __typestr.at(2) == '8') { \
|
|
||||||
__dispatched_func<int64_t>(__VA_ARGS__); \
|
|
||||||
} else if (__typestr.at(1) == 'u' && __typestr.at(2) == '1') { \
|
|
||||||
__dispatched_func<uint8_t>(__VA_ARGS__); \
|
|
||||||
} else if (__typestr.at(1) == 'u' && __typestr.at(2) == '2') { \
|
|
||||||
__dispatched_func<uint16_t>(__VA_ARGS__); \
|
|
||||||
} else if (__typestr.at(1) == 'u' && __typestr.at(2) == '4') { \
|
|
||||||
__dispatched_func<uint32_t>(__VA_ARGS__); \
|
|
||||||
} else if (__typestr.at(1) == 'u' && __typestr.at(2) == '8') { \
|
|
||||||
__dispatched_func<uint64_t>(__VA_ARGS__); \
|
|
||||||
} else { \
|
|
||||||
LOG(FATAL) << ColumnarErrors::UnSupportedType(__typestr); \
|
|
||||||
} \
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_DATA_COLUMNAR_H_
|
#endif // XGBOOST_DATA_COLUMNAR_H_
|
||||||
|
|||||||
@ -12,7 +12,6 @@
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<float>* out) {
|
void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<float>* out) {
|
||||||
auto SetDeviceToPtr = [](void* ptr) {
|
auto SetDeviceToPtr = [](void* ptr) {
|
||||||
cudaPointerAttributes attr;
|
cudaPointerAttributes attr;
|
||||||
@ -21,17 +20,17 @@ void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<fl
|
|||||||
dh::safe_cuda(cudaSetDevice(ptr_device));
|
dh::safe_cuda(cudaSetDevice(ptr_device));
|
||||||
return ptr_device;
|
return ptr_device;
|
||||||
};
|
};
|
||||||
|
Columnar foreign_column(column);
|
||||||
|
auto ptr_device = SetDeviceToPtr(foreign_column.data);
|
||||||
|
|
||||||
common::Span<T> s_data { ArrayInterfaceHandler::ExtractData<T>(column) };
|
|
||||||
auto ptr_device = SetDeviceToPtr(s_data.data());
|
|
||||||
thrust::device_ptr<T> p_src {s_data.data()};
|
|
||||||
|
|
||||||
auto length = s_data.size();
|
|
||||||
out->SetDevice(ptr_device);
|
out->SetDevice(ptr_device);
|
||||||
out->Resize(length);
|
out->Resize(foreign_column.size);
|
||||||
|
|
||||||
auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
|
auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
|
||||||
thrust::copy(p_src, p_src + length, p_dst);
|
|
||||||
|
dh::LaunchN(ptr_device, foreign_column.size, [=] __device__(size_t idx) {
|
||||||
|
p_dst[idx] = foreign_column.GetElement(idx);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||||
@ -46,14 +45,12 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
|||||||
}
|
}
|
||||||
auto const& typestr = get<String const>(j_arr_obj.at("typestr"));
|
auto const& typestr = get<String const>(j_arr_obj.at("typestr"));
|
||||||
|
|
||||||
if (key == "root_index") {
|
if (key == "label") {
|
||||||
LOG(FATAL) << "root index for columnar data is not supported.";
|
CopyInfoImpl(j_arr_obj, &labels_);
|
||||||
} else if (key == "label") {
|
|
||||||
DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &labels_);
|
|
||||||
} else if (key == "weight") {
|
} else if (key == "weight") {
|
||||||
DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &weights_);
|
CopyInfoImpl(j_arr_obj, &weights_);
|
||||||
} else if (key == "base_margin") {
|
} else if (key == "base_margin") {
|
||||||
DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &base_margin_);
|
CopyInfoImpl(j_arr_obj, &base_margin_);
|
||||||
} else if (key == "group") {
|
} else if (key == "group") {
|
||||||
// Ranking is not performed on device.
|
// Ranking is not performed on device.
|
||||||
auto s_data = ArrayInterfaceHandler::ExtractData<uint32_t>(j_arr_obj);
|
auto s_data = ArrayInterfaceHandler::ExtractData<uint32_t>(j_arr_obj);
|
||||||
|
|||||||
@ -26,8 +26,7 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
template <typename T>
|
__global__ void CountValidKernel(Columnar const column,
|
||||||
__global__ void CountValidKernel(Columnar<T> const column,
|
|
||||||
bool has_missing, float missing,
|
bool has_missing, float missing,
|
||||||
int32_t* flag, common::Span<bst_row_t> offsets) {
|
int32_t* flag, common::Span<bst_row_t> offsets) {
|
||||||
auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
|
auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
|
||||||
@ -40,18 +39,18 @@ __global__ void CountValidKernel(Columnar<T> const column,
|
|||||||
|
|
||||||
if (!has_missing) {
|
if (!has_missing) {
|
||||||
if ((mask.Data() == nullptr || mask.Check(tid)) &&
|
if ((mask.Data() == nullptr || mask.Check(tid)) &&
|
||||||
!common::CheckNAN(column.data[tid])) {
|
!common::CheckNAN(column.GetElement(tid))) {
|
||||||
offsets[tid+1] += 1;
|
offsets[tid+1] += 1;
|
||||||
}
|
}
|
||||||
} else if (missing_is_nan) {
|
} else if (missing_is_nan) {
|
||||||
if (!common::CheckNAN(column.data[tid])) {
|
if (!common::CheckNAN(column.GetElement(tid))) {
|
||||||
offsets[tid+1] += 1;
|
offsets[tid+1] += 1;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!common::CloseTo(column.data[tid], missing)) {
|
if (!common::CloseTo(column.GetElement(tid), missing)) {
|
||||||
offsets[tid+1] += 1;
|
offsets[tid+1] += 1;
|
||||||
}
|
}
|
||||||
if (common::CheckNAN(column.data[tid])) {
|
if (common::CheckNAN(column.GetElement(tid))) {
|
||||||
*flag = 1;
|
*flag = 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -67,8 +66,7 @@ __device__ void AssignValue(T fvalue, int32_t colid,
|
|||||||
out_offsets[tid] += 1;
|
out_offsets[tid] += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
__global__ void CreateCSRKernel(Columnar const column,
|
||||||
__global__ void CreateCSRKernel(Columnar<T> const column,
|
|
||||||
int32_t colid, bool has_missing, float missing,
|
int32_t colid, bool has_missing, float missing,
|
||||||
common::Span<bst_row_t> offsets, common::Span<Entry> out_data) {
|
common::Span<bst_row_t> offsets, common::Span<Entry> out_data) {
|
||||||
auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
|
auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
|
||||||
@ -79,23 +77,22 @@ __global__ void CreateCSRKernel(Columnar<T> const column,
|
|||||||
if (!has_missing) {
|
if (!has_missing) {
|
||||||
// no missing value is specified
|
// no missing value is specified
|
||||||
if ((column.valid.Data() == nullptr || column.valid.Check(tid)) &&
|
if ((column.valid.Data() == nullptr || column.valid.Check(tid)) &&
|
||||||
!common::CheckNAN(column.data[tid])) {
|
!common::CheckNAN(column.GetElement(tid))) {
|
||||||
AssignValue(column.data[tid], colid, offsets, out_data);
|
AssignValue(column.GetElement(tid), colid, offsets, out_data);
|
||||||
}
|
}
|
||||||
} else if (missing_is_nan) {
|
} else if (missing_is_nan) {
|
||||||
// specified missing value, but it's NaN
|
// specified missing value, but it's NaN
|
||||||
if (!common::CheckNAN(column.data[tid])) {
|
if (!common::CheckNAN(column.GetElement(tid))) {
|
||||||
AssignValue(column.data[tid], colid, offsets, out_data);
|
AssignValue(column.GetElement(tid), colid, offsets, out_data);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// specified missing value, and it's not NaN
|
// specified missing value, and it's not NaN
|
||||||
if (!common::CloseTo(column.data[tid], missing)) {
|
if (!common::CloseTo(column.GetElement(tid), missing)) {
|
||||||
AssignValue(column.data[tid], colid, offsets, out_data);
|
AssignValue(column.GetElement(tid), colid, offsets, out_data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
|
void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
|
||||||
bool has_missing, float missing,
|
bool has_missing, float missing,
|
||||||
HostDeviceVector<bst_row_t>* out_offset,
|
HostDeviceVector<bst_row_t>* out_offset,
|
||||||
@ -104,10 +101,10 @@ void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
|
|||||||
uint32_t constexpr kThreads = 256;
|
uint32_t constexpr kThreads = 256;
|
||||||
auto const& j_column = j_columns[column_id];
|
auto const& j_column = j_columns[column_id];
|
||||||
auto const& column_obj = get<Object const>(j_column);
|
auto const& column_obj = get<Object const>(j_column);
|
||||||
Columnar<T> foreign_column = ArrayInterfaceHandler::ExtractArray<T>(column_obj);
|
Columnar foreign_column(column_obj);
|
||||||
uint32_t const n_rows = foreign_column.size;
|
uint32_t const n_rows = foreign_column.size;
|
||||||
|
|
||||||
auto ptr = foreign_column.data.data();
|
auto ptr = foreign_column.data;
|
||||||
int32_t device = dh::CudaGetPointerDevice(ptr);
|
int32_t device = dh::CudaGetPointerDevice(ptr);
|
||||||
CHECK_NE(device, -1);
|
CHECK_NE(device, -1);
|
||||||
dh::safe_cuda(cudaSetDevice(device));
|
dh::safe_cuda(cudaSetDevice(device));
|
||||||
@ -125,24 +122,23 @@ void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
|
|||||||
|
|
||||||
uint32_t const kBlocks = common::DivRoundUp(n_rows, kThreads);
|
uint32_t const kBlocks = common::DivRoundUp(n_rows, kThreads);
|
||||||
dh::LaunchKernel {kBlocks, kThreads} (
|
dh::LaunchKernel {kBlocks, kThreads} (
|
||||||
CountValidKernel<T>,
|
CountValidKernel,
|
||||||
foreign_column,
|
foreign_column,
|
||||||
has_missing, missing,
|
has_missing, missing,
|
||||||
out_d_flag->data().get(), s_offsets);
|
out_d_flag->data().get(), s_offsets);
|
||||||
*out_n_rows = n_rows;
|
*out_n_rows = n_rows;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void CreateCSR(std::vector<Json> const& j_columns, uint32_t column_id, uint32_t n_rows,
|
void CreateCSR(std::vector<Json> const& j_columns, uint32_t column_id, uint32_t n_rows,
|
||||||
bool has_missing, float missing,
|
bool has_missing, float missing,
|
||||||
dh::device_vector<bst_row_t>* tmp_offset, common::Span<Entry> s_data) {
|
dh::device_vector<bst_row_t>* tmp_offset, common::Span<Entry> s_data) {
|
||||||
uint32_t constexpr kThreads = 256;
|
uint32_t constexpr kThreads = 256;
|
||||||
auto const& j_column = j_columns[column_id];
|
auto const& j_column = j_columns[column_id];
|
||||||
auto const& column_obj = get<Object const>(j_column);
|
auto const& column_obj = get<Object const>(j_column);
|
||||||
Columnar<T> foreign_column = ArrayInterfaceHandler::ExtractArray<T>(column_obj);
|
Columnar foreign_column(column_obj);
|
||||||
uint32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
|
uint32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
|
||||||
dh::LaunchKernel {kBlocks, kThreads} (
|
dh::LaunchKernel {kBlocks, kThreads} (
|
||||||
CreateCSRKernel<T>,
|
CreateCSRKernel,
|
||||||
foreign_column, column_id, has_missing, missing,
|
foreign_column, column_id, has_missing, missing,
|
||||||
dh::ToSpan(*tmp_offset), s_data);
|
dh::ToSpan(*tmp_offset), s_data);
|
||||||
}
|
}
|
||||||
@ -159,9 +155,8 @@ void SimpleCSRSource::FromDeviceColumnar(std::vector<Json> const& columns,
|
|||||||
}
|
}
|
||||||
uint32_t n_rows {0};
|
uint32_t n_rows {0};
|
||||||
for (size_t i = 0; i < n_cols; ++i) {
|
for (size_t i = 0; i < n_cols; ++i) {
|
||||||
auto const& typestr = get<String const>(columns[i]["typestr"]);
|
CountValid(columns, i, has_missing, missing, &(this->page_.offset), &d_flag,
|
||||||
DISPATCH_TYPE(CountValid, typestr,
|
&n_rows);
|
||||||
columns, i, has_missing, missing, &(this->page_.offset), &d_flag, &n_rows);
|
|
||||||
}
|
}
|
||||||
// don't pay for what you don't use.
|
// don't pay for what you don't use.
|
||||||
if (!common::CheckNAN(missing)) {
|
if (!common::CheckNAN(missing)) {
|
||||||
@ -197,9 +192,7 @@ void SimpleCSRSource::FromDeviceColumnar(std::vector<Json> const& columns,
|
|||||||
|
|
||||||
int32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
|
int32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
|
||||||
for (size_t i = 0; i < n_cols; ++i) {
|
for (size_t i = 0; i < n_cols; ++i) {
|
||||||
auto const& typestr = get<String const>(columns[i]["typestr"]);
|
CreateCSR(columns, i, n_rows, has_missing, missing, &tmp_offset, s_data);
|
||||||
DISPATCH_TYPE(CreateCSR, typestr, columns, i, n_rows,
|
|
||||||
has_missing, missing, &tmp_offset, s_data);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -23,21 +23,21 @@ TEST(ArrayInterfaceHandler, Error) {
|
|||||||
|
|
||||||
auto const& column_obj = get<Object>(column);
|
auto const& column_obj = get<Object>(column);
|
||||||
// missing version
|
// missing version
|
||||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
|
EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
|
||||||
column["version"] = Integer(static_cast<Integer::Int>(1));
|
column["version"] = Integer(static_cast<Integer::Int>(1));
|
||||||
// missing data
|
// missing data
|
||||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
|
EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
|
||||||
column["data"] = j_data;
|
column["data"] = j_data;
|
||||||
// missing typestr
|
// missing typestr
|
||||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
|
EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
|
||||||
column["typestr"] = String("<f4");
|
column["typestr"] = String("<f4");
|
||||||
// nullptr is not valid
|
// nullptr is not valid
|
||||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
|
EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
|
||||||
thrust::device_vector<float> d_data(kRows);
|
thrust::device_vector<float> d_data(kRows);
|
||||||
j_data = {Json(Integer(reinterpret_cast<Integer::Int>(d_data.data().get()))),
|
j_data = {Json(Integer(reinterpret_cast<Integer::Int>(d_data.data().get()))),
|
||||||
Json(Boolean(false))};
|
Json(Boolean(false))};
|
||||||
column["data"] = j_data;
|
column["data"] = j_data;
|
||||||
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj));
|
EXPECT_NO_THROW(Columnar c(column_obj));
|
||||||
|
|
||||||
std::vector<Json> j_mask_shape {Json(Integer(static_cast<Integer::Int>(kRows - 1)))};
|
std::vector<Json> j_mask_shape {Json(Integer(static_cast<Integer::Int>(kRows - 1)))};
|
||||||
column["mask"] = Object();
|
column["mask"] = Object();
|
||||||
@ -46,7 +46,7 @@ TEST(ArrayInterfaceHandler, Error) {
|
|||||||
column["mask"]["typestr"] = String("<i1");
|
column["mask"]["typestr"] = String("<i1");
|
||||||
column["mask"]["version"] = Integer(static_cast<Integer::Int>(1));
|
column["mask"]["version"] = Integer(static_cast<Integer::Int>(1));
|
||||||
// shape of mask and data doesn't match.
|
// shape of mask and data doesn't match.
|
||||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error);
|
EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -75,6 +75,23 @@ Json GenerateDenseColumn(std::string const& typestr, size_t kRows,
|
|||||||
return column;
|
return column;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestGetElement() {
|
||||||
|
thrust::device_vector<float> data;
|
||||||
|
auto j_column = GenerateDenseColumn("<f4", 3, &data);
|
||||||
|
auto const& column_obj = get<Object const>(j_column);
|
||||||
|
Columnar foreign_column(column_obj);
|
||||||
|
|
||||||
|
EXPECT_NO_THROW({
|
||||||
|
dh::LaunchN(0, 1, [=] __device__(size_t idx) {
|
||||||
|
KERNEL_CHECK(foreign_column.GetElement(0) == 0.0f);
|
||||||
|
KERNEL_CHECK(foreign_column.GetElement(1) == 2.0f);
|
||||||
|
KERNEL_CHECK(foreign_column.GetElement(2) == 4.0f);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Columnar, GetElement) { TestGetElement(); }
|
||||||
|
|
||||||
void TestDenseColumn(std::unique_ptr<data::SimpleCSRSource> const& source,
|
void TestDenseColumn(std::unique_ptr<data::SimpleCSRSource> const& source,
|
||||||
size_t n_rows, size_t n_cols) {
|
size_t n_rows, size_t n_cols) {
|
||||||
auto const& data = source->page_.data.HostVector();
|
auto const& data = source->page_.data.HostVector();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user