Use generic dispatching routine for array interface. (#6672)
This commit is contained in:
parent
a4101de678
commit
1e949110da
@ -315,40 +315,50 @@ class ArrayInterface {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
|
template <typename Fn>
|
||||||
void* p_values;
|
XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case kF4:
|
case kF4:
|
||||||
p_values = reinterpret_cast<float *>(data) + offset;
|
return func(reinterpret_cast<float *>(data));
|
||||||
break;
|
break;
|
||||||
case kF8:
|
case kF8:
|
||||||
p_values = reinterpret_cast<double *>(data) + offset;
|
return func(reinterpret_cast<double *>(data));
|
||||||
break;
|
break;
|
||||||
case kI1:
|
case kI1:
|
||||||
p_values = reinterpret_cast<int8_t *>(data) + offset;
|
return func(reinterpret_cast<int8_t *>(data));
|
||||||
break;
|
break;
|
||||||
case kI2:
|
case kI2:
|
||||||
p_values = reinterpret_cast<int16_t *>(data) + offset;
|
return func(reinterpret_cast<int16_t *>(data));
|
||||||
break;
|
break;
|
||||||
case kI4:
|
case kI4:
|
||||||
p_values = reinterpret_cast<int32_t *>(data) + offset;
|
return func(reinterpret_cast<int32_t *>(data));
|
||||||
break;
|
break;
|
||||||
case kI8:
|
case kI8:
|
||||||
p_values = reinterpret_cast<int64_t *>(data) + offset;
|
return func(reinterpret_cast<int64_t *>(data));
|
||||||
break;
|
break;
|
||||||
case kU1:
|
case kU1:
|
||||||
p_values = reinterpret_cast<uint8_t *>(data) + offset;
|
return func(reinterpret_cast<uint8_t *>(data));
|
||||||
break;
|
break;
|
||||||
case kU2:
|
case kU2:
|
||||||
p_values = reinterpret_cast<uint16_t *>(data) + offset;
|
return func(reinterpret_cast<uint16_t *>(data));
|
||||||
break;
|
break;
|
||||||
case kU4:
|
case kU4:
|
||||||
p_values = reinterpret_cast<uint32_t *>(data) + offset;
|
return func(reinterpret_cast<uint32_t *>(data));
|
||||||
break;
|
break;
|
||||||
case kU8:
|
case kU8:
|
||||||
p_values = reinterpret_cast<uint64_t *>(data) + offset;
|
return func(reinterpret_cast<uint64_t *>(data));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
SPAN_CHECK(false);
|
||||||
|
return func(reinterpret_cast<uint64_t *>(data));
|
||||||
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
|
||||||
|
void* p_values{nullptr};
|
||||||
|
this->DispatchCall([&p_values, offset](auto *ptr) {
|
||||||
|
p_values = ptr + offset;
|
||||||
|
});
|
||||||
|
|
||||||
ArrayInterface ret = *this;
|
ArrayInterface ret = *this;
|
||||||
ret.data = p_values;
|
ret.data = p_values;
|
||||||
return ret;
|
return ret;
|
||||||
@ -390,6 +400,12 @@ class ArrayInterface {
|
|||||||
return reinterpret_cast<float*>(data)[idx];
|
return reinterpret_cast<float*>(data)[idx];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XGBOOST_DEVICE size_t ElementSize() {
|
||||||
|
return this->DispatchCall([](auto* p_values) {
|
||||||
|
return sizeof(std::remove_pointer_t<decltype(p_values)>);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
RBitField8 valid;
|
RBitField8 valid;
|
||||||
bst_row_t num_rows;
|
bst_row_t num_rows;
|
||||||
bst_feature_t num_cols;
|
bst_feature_t num_cols;
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2020 by XGBoost Contributors
|
* Copyright 2020-2021 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
@ -15,6 +15,17 @@ TEST(ArrayInterface, Initialize) {
|
|||||||
ASSERT_EQ(arr_interface.num_rows, kRows);
|
ASSERT_EQ(arr_interface.num_rows, kRows);
|
||||||
ASSERT_EQ(arr_interface.num_cols, kCols);
|
ASSERT_EQ(arr_interface.num_cols, kCols);
|
||||||
ASSERT_EQ(arr_interface.data, storage.ConstHostPointer());
|
ASSERT_EQ(arr_interface.data, storage.ConstHostPointer());
|
||||||
|
ASSERT_EQ(arr_interface.ElementSize(), 4);
|
||||||
|
ASSERT_EQ(arr_interface.type, ArrayInterface::kF4);
|
||||||
|
|
||||||
|
HostDeviceVector<size_t> u64_storage(storage.Size());
|
||||||
|
std::string u64_arr_str;
|
||||||
|
Json::Dump(GetArrayInterface(&u64_storage, kRows, kCols), &u64_arr_str);
|
||||||
|
std::copy(storage.ConstHostVector().cbegin(), storage.ConstHostVector().cend(),
|
||||||
|
u64_storage.HostSpan().begin());
|
||||||
|
auto u64_arr = ArrayInterface{u64_arr_str};
|
||||||
|
ASSERT_EQ(u64_arr.ElementSize(), 8);
|
||||||
|
ASSERT_EQ(u64_arr.type, ArrayInterface::kU8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ArrayInterface, Error) {
|
TEST(ArrayInterface, Error) {
|
||||||
|
|||||||
@ -190,24 +190,7 @@ void RandomDataGenerator::GenerateDense(HostDeviceVector<float> *out) const {
|
|||||||
Json RandomDataGenerator::ArrayInterfaceImpl(HostDeviceVector<float> *storage,
|
Json RandomDataGenerator::ArrayInterfaceImpl(HostDeviceVector<float> *storage,
|
||||||
size_t rows, size_t cols) const {
|
size_t rows, size_t cols) const {
|
||||||
this->GenerateDense(storage);
|
this->GenerateDense(storage);
|
||||||
Json array_interface {Object()};
|
return GetArrayInterface(storage, rows, cols);
|
||||||
array_interface["data"] = std::vector<Json>(2);
|
|
||||||
if (storage->DeviceCanRead()) {
|
|
||||||
array_interface["data"][0] =
|
|
||||||
Integer(reinterpret_cast<int64_t>(storage->ConstDevicePointer()));
|
|
||||||
} else {
|
|
||||||
array_interface["data"][0] =
|
|
||||||
Integer(reinterpret_cast<int64_t>(storage->ConstHostPointer()));
|
|
||||||
}
|
|
||||||
array_interface["data"][1] = Boolean(false);
|
|
||||||
|
|
||||||
array_interface["shape"] = std::vector<Json>(2);
|
|
||||||
array_interface["shape"][0] = rows;
|
|
||||||
array_interface["shape"][1] = cols;
|
|
||||||
|
|
||||||
array_interface["typestr"] = String("<f4");
|
|
||||||
array_interface["version"] = 1;
|
|
||||||
return array_interface;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string RandomDataGenerator::GenerateArrayInterface(
|
std::string RandomDataGenerator::GenerateArrayInterface(
|
||||||
|
|||||||
@ -22,6 +22,7 @@
|
|||||||
|
|
||||||
#include "../../src/common/common.h"
|
#include "../../src/common/common.h"
|
||||||
#include "../../src/gbm/gbtree_model.h"
|
#include "../../src/gbm/gbtree_model.h"
|
||||||
|
#include "../../src/data/array_interface.h"
|
||||||
|
|
||||||
#if defined(__CUDACC__)
|
#if defined(__CUDACC__)
|
||||||
#define DeclareUnifiedTest(name) GPU ## name
|
#define DeclareUnifiedTest(name) GPU ## name
|
||||||
@ -181,6 +182,29 @@ class SimpleRealUniformDistribution {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Json GetArrayInterface(HostDeviceVector<T> *storage, size_t rows, size_t cols) {
|
||||||
|
Json array_interface{Object()};
|
||||||
|
array_interface["data"] = std::vector<Json>(2);
|
||||||
|
if (storage->DeviceCanRead()) {
|
||||||
|
array_interface["data"][0] =
|
||||||
|
Integer(reinterpret_cast<int64_t>(storage->ConstDevicePointer()));
|
||||||
|
} else {
|
||||||
|
array_interface["data"][0] =
|
||||||
|
Integer(reinterpret_cast<int64_t>(storage->ConstHostPointer()));
|
||||||
|
}
|
||||||
|
array_interface["data"][1] = Boolean(false);
|
||||||
|
|
||||||
|
array_interface["shape"] = std::vector<Json>(2);
|
||||||
|
array_interface["shape"][0] = rows;
|
||||||
|
array_interface["shape"][1] = cols;
|
||||||
|
|
||||||
|
char t = ArrayInterfaceHandler::TypeChar<T>();
|
||||||
|
array_interface["typestr"] = String(std::string{"<"} + t + std::to_string(sizeof(T)));
|
||||||
|
array_interface["version"] = 1;
|
||||||
|
return array_interface;
|
||||||
|
}
|
||||||
|
|
||||||
// Generate in-memory random data without using DMatrix.
|
// Generate in-memory random data without using DMatrix.
|
||||||
class RandomDataGenerator {
|
class RandomDataGenerator {
|
||||||
bst_row_t rows_;
|
bst_row_t rows_;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user