Accept string for ArrayInterface constructor.
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
/*!
|
||||
* Copyright 2019 by Contributors
|
||||
* \file array_interface.h
|
||||
* \brief Basic structure holding a reference to arrow columnar data format.
|
||||
* \brief View of __array_interface__
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/logging.h"
|
||||
@@ -113,6 +114,7 @@ class ArrayInterfaceHandler {
|
||||
get<Array const>(
|
||||
obj.at("data"))
|
||||
.at(0))));
|
||||
CHECK(p_data);
|
||||
return p_data;
|
||||
}
|
||||
|
||||
@@ -186,7 +188,7 @@ class ArrayInterfaceHandler {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::pair<size_t, size_t> ExtractShape(
|
||||
static std::pair<bst_row_t, bst_feature_t> ExtractShape(
|
||||
std::map<std::string, Json> const& column) {
|
||||
auto j_shape = get<Array const>(column.at("shape"));
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
@@ -201,12 +203,12 @@ class ArrayInterfaceHandler {
|
||||
}
|
||||
|
||||
if (j_shape.size() == 1) {
|
||||
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))), 1};
|
||||
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1};
|
||||
} else {
|
||||
CHECK_EQ(j_shape.size(), 2)
|
||||
<< "Only 1D or 2-D arrays currently supported.";
|
||||
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))),
|
||||
static_cast<size_t>(get<Integer const>(j_shape.at(1)))};
|
||||
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))),
|
||||
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
@@ -219,7 +221,6 @@ class ArrayInterfaceHandler {
|
||||
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
|
||||
<< "Input data type and typestr mismatch. typestr: " << typestr;
|
||||
|
||||
|
||||
auto shape = ExtractShape(column);
|
||||
|
||||
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
||||
@@ -231,8 +232,8 @@ class ArrayInterfaceHandler {
|
||||
class ArrayInterface {
|
||||
public:
|
||||
ArrayInterface() = default;
|
||||
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
void Initialize(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
ArrayInterfaceHandler::Validate(column);
|
||||
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
||||
CHECK(data) << "Column is null";
|
||||
@@ -263,6 +264,25 @@ class ArrayInterface {
|
||||
this->CheckType();
|
||||
}
|
||||
|
||||
explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
|
||||
auto jinterface = Json::Load({str.c_str(), str.size()});
|
||||
if (IsA<Object>(jinterface)) {
|
||||
this->Initialize(get<Object const>(jinterface), allow_mask);
|
||||
return;
|
||||
}
|
||||
if (IsA<Array>(jinterface)) {
|
||||
CHECK_EQ(get<Array const>(jinterface).size(), 1)
|
||||
<< "Column: " << ArrayInterfaceErrors::Dimension(1);
|
||||
this->Initialize(get<Object const>(get<Array const>(jinterface)[0]), allow_mask);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
this->Initialize(column, allow_mask);
|
||||
}
|
||||
|
||||
void CheckType() const {
|
||||
if (type[1] == 'f' && type[2] == '4') {
|
||||
return;
|
||||
@@ -291,6 +311,7 @@ class ArrayInterface {
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE float GetElement(size_t idx) const {
|
||||
SPAN_CHECK(idx < num_cols * num_rows);
|
||||
if (type[1] == 'f' && type[2] == '4') {
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
} else if (type[1] == 'f' && type[2] == '8') {
|
||||
@@ -318,8 +339,8 @@ class ArrayInterface {
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
int32_t num_rows;
|
||||
int32_t num_cols;
|
||||
bst_row_t num_rows;
|
||||
bst_feature_t num_cols;
|
||||
void* data;
|
||||
char type[3];
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user