Accept string for ArrayInterface constructor.
This commit is contained in:
parent
b47b5ac771
commit
e8ecafb8dc
@ -1,7 +1,7 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2019 by Contributors
|
* Copyright 2019 by Contributors
|
||||||
* \file array_interface.h
|
* \file array_interface.h
|
||||||
* \brief Basic structure holding a reference to arrow columnar data format.
|
* \brief View of __array_interface__
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
|
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||||
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
|
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||||
@ -11,6 +11,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
@ -113,6 +114,7 @@ class ArrayInterfaceHandler {
|
|||||||
get<Array const>(
|
get<Array const>(
|
||||||
obj.at("data"))
|
obj.at("data"))
|
||||||
.at(0))));
|
.at(0))));
|
||||||
|
CHECK(p_data);
|
||||||
return p_data;
|
return p_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -186,7 +188,7 @@ class ArrayInterfaceHandler {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::pair<size_t, size_t> ExtractShape(
|
static std::pair<bst_row_t, bst_feature_t> ExtractShape(
|
||||||
std::map<std::string, Json> const& column) {
|
std::map<std::string, Json> const& column) {
|
||||||
auto j_shape = get<Array const>(column.at("shape"));
|
auto j_shape = get<Array const>(column.at("shape"));
|
||||||
auto typestr = get<String const>(column.at("typestr"));
|
auto typestr = get<String const>(column.at("typestr"));
|
||||||
@ -201,12 +203,12 @@ class ArrayInterfaceHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (j_shape.size() == 1) {
|
if (j_shape.size() == 1) {
|
||||||
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))), 1};
|
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1};
|
||||||
} else {
|
} else {
|
||||||
CHECK_EQ(j_shape.size(), 2)
|
CHECK_EQ(j_shape.size(), 2)
|
||||||
<< "Only 1D or 2-D arrays currently supported.";
|
<< "Only 1D or 2-D arrays currently supported.";
|
||||||
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))),
|
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))),
|
||||||
static_cast<size_t>(get<Integer const>(j_shape.at(1)))};
|
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -219,7 +221,6 @@ class ArrayInterfaceHandler {
|
|||||||
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
|
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
|
||||||
<< "Input data type and typestr mismatch. typestr: " << typestr;
|
<< "Input data type and typestr mismatch. typestr: " << typestr;
|
||||||
|
|
||||||
|
|
||||||
auto shape = ExtractShape(column);
|
auto shape = ExtractShape(column);
|
||||||
|
|
||||||
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
||||||
@ -231,8 +232,8 @@ class ArrayInterfaceHandler {
|
|||||||
class ArrayInterface {
|
class ArrayInterface {
|
||||||
public:
|
public:
|
||||||
ArrayInterface() = default;
|
ArrayInterface() = default;
|
||||||
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
void Initialize(std::map<std::string, Json> const &column,
|
||||||
bool allow_mask = true) {
|
bool allow_mask = true) {
|
||||||
ArrayInterfaceHandler::Validate(column);
|
ArrayInterfaceHandler::Validate(column);
|
||||||
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
||||||
CHECK(data) << "Column is null";
|
CHECK(data) << "Column is null";
|
||||||
@ -263,6 +264,25 @@ class ArrayInterface {
|
|||||||
this->CheckType();
|
this->CheckType();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
|
||||||
|
auto jinterface = Json::Load({str.c_str(), str.size()});
|
||||||
|
if (IsA<Object>(jinterface)) {
|
||||||
|
this->Initialize(get<Object const>(jinterface), allow_mask);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (IsA<Array>(jinterface)) {
|
||||||
|
CHECK_EQ(get<Array const>(jinterface).size(), 1)
|
||||||
|
<< "Column: " << ArrayInterfaceErrors::Dimension(1);
|
||||||
|
this->Initialize(get<Object const>(get<Array const>(jinterface)[0]), allow_mask);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
||||||
|
bool allow_mask = true) {
|
||||||
|
this->Initialize(column, allow_mask);
|
||||||
|
}
|
||||||
|
|
||||||
void CheckType() const {
|
void CheckType() const {
|
||||||
if (type[1] == 'f' && type[2] == '4') {
|
if (type[1] == 'f' && type[2] == '4') {
|
||||||
return;
|
return;
|
||||||
@ -291,6 +311,7 @@ class ArrayInterface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE float GetElement(size_t idx) const {
|
XGBOOST_DEVICE float GetElement(size_t idx) const {
|
||||||
|
SPAN_CHECK(idx < num_cols * num_rows);
|
||||||
if (type[1] == 'f' && type[2] == '4') {
|
if (type[1] == 'f' && type[2] == '4') {
|
||||||
return reinterpret_cast<float*>(data)[idx];
|
return reinterpret_cast<float*>(data)[idx];
|
||||||
} else if (type[1] == 'f' && type[2] == '8') {
|
} else if (type[1] == 'f' && type[2] == '8') {
|
||||||
@ -318,8 +339,8 @@ class ArrayInterface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
RBitField8 valid;
|
RBitField8 valid;
|
||||||
int32_t num_rows;
|
bst_row_t num_rows;
|
||||||
int32_t num_cols;
|
bst_feature_t num_cols;
|
||||||
void* data;
|
void* data;
|
||||||
char type[3];
|
char type[3];
|
||||||
};
|
};
|
||||||
|
|||||||
@ -63,7 +63,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
|||||||
auto const& j_arr = get<Array>(j_interface);
|
auto const& j_arr = get<Array>(j_interface);
|
||||||
CHECK_EQ(j_arr.size(), 1)
|
CHECK_EQ(j_arr.size(), 1)
|
||||||
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
|
<< "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1);
|
||||||
ArrayInterface array_interface(get<Object const>(j_arr[0]));
|
ArrayInterface array_interface(interface_str);
|
||||||
std::string key{c_key};
|
std::string key{c_key};
|
||||||
CHECK(!array_interface.valid.Data())
|
CHECK(!array_interface.valid.Data())
|
||||||
<< "Meta info " << key << " should be dense, found validity mask";
|
<< "Meta info " << key << " should be dense, found validity mask";
|
||||||
|
|||||||
51
tests/cpp/data/test_array_interface.cc
Normal file
51
tests/cpp/data/test_array_interface.cc
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2020 by XGBoost Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/host_device_vector.h>
|
||||||
|
#include "../helpers.h"
|
||||||
|
#include "../../../src/data/array_interface.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
TEST(ArrayInterface, Initialize) {
|
||||||
|
size_t constexpr kRows = 10, kCols = 10;
|
||||||
|
HostDeviceVector<float> storage;
|
||||||
|
auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
|
||||||
|
auto arr_interface = ArrayInterface(array);
|
||||||
|
ASSERT_EQ(arr_interface.num_rows, kRows);
|
||||||
|
ASSERT_EQ(arr_interface.num_cols, kCols);
|
||||||
|
ASSERT_EQ(arr_interface.data, storage.ConstHostPointer());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ArrayInterface, Error) {
|
||||||
|
constexpr size_t kRows = 16, kCols = 10;
|
||||||
|
Json column { Object() };
|
||||||
|
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
|
||||||
|
column["shape"] = Array(j_shape);
|
||||||
|
std::vector<Json> j_data {
|
||||||
|
Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
|
||||||
|
Json(Boolean(false))};
|
||||||
|
|
||||||
|
auto const& column_obj = get<Object>(column);
|
||||||
|
// missing version
|
||||||
|
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
|
||||||
|
column["version"] = Integer(static_cast<Integer::Int>(1));
|
||||||
|
// missing data
|
||||||
|
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
|
||||||
|
column["data"] = j_data;
|
||||||
|
// missing typestr
|
||||||
|
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
|
||||||
|
column["typestr"] = String("<f4");
|
||||||
|
// nullptr is not valid
|
||||||
|
EXPECT_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj), dmlc::Error);
|
||||||
|
|
||||||
|
HostDeviceVector<float> storage;
|
||||||
|
auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
|
||||||
|
j_data = {
|
||||||
|
Json(Integer(reinterpret_cast<Integer::Int>(storage.ConstHostPointer()))),
|
||||||
|
Json(Boolean(false))};
|
||||||
|
column["data"] = j_data;
|
||||||
|
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData<float>(column_obj));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xgboost
|
||||||
@ -182,7 +182,13 @@ Json RandomDataGenerator::ArrayInterfaceImpl(HostDeviceVector<float> *storage,
|
|||||||
this->GenerateDense(storage);
|
this->GenerateDense(storage);
|
||||||
Json array_interface {Object()};
|
Json array_interface {Object()};
|
||||||
array_interface["data"] = std::vector<Json>(2);
|
array_interface["data"] = std::vector<Json>(2);
|
||||||
array_interface["data"][0] = Integer(reinterpret_cast<int64_t>(storage->DevicePointer()));
|
if (storage->DeviceCanRead()) {
|
||||||
|
array_interface["data"][0] =
|
||||||
|
Integer(reinterpret_cast<int64_t>(storage->ConstDevicePointer()));
|
||||||
|
} else {
|
||||||
|
array_interface["data"][0] =
|
||||||
|
Integer(reinterpret_cast<int64_t>(storage->ConstHostPointer()));
|
||||||
|
}
|
||||||
array_interface["data"][1] = Boolean(false);
|
array_interface["data"][1] = Boolean(false);
|
||||||
|
|
||||||
array_interface["shape"] = std::vector<Json>(2);
|
array_interface["shape"] = std::vector<Json>(2);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user