Revert "Accept string for ArrayInterface constructor."

This reverts commit e8ecafb8dc.
This commit is contained in:
fis
2020-06-16 20:02:35 +08:00
parent e8ecafb8dc
commit 7c3a168ffd
4 changed files with 12 additions and 90 deletions

View File

@@ -1,7 +1,7 @@
/*!
* Copyright 2019 by Contributors
* \file array_interface.h
* \brief View of __array_interface__
* \brief Basic structure holding a reference to arrow columnar data format.
*/
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
@@ -11,7 +11,6 @@
#include <string>
#include <utility>
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "xgboost/logging.h"
@@ -114,7 +113,6 @@ class ArrayInterfaceHandler {
get<Array const>(
obj.at("data"))
.at(0))));
CHECK(p_data);
return p_data;
}
@@ -188,7 +186,7 @@ class ArrayInterfaceHandler {
return 0;
}
static std::pair<bst_row_t, bst_feature_t> ExtractShape(
static std::pair<size_t, size_t> ExtractShape(
std::map<std::string, Json> const& column) {
auto j_shape = get<Array const>(column.at("shape"));
auto typestr = get<String const>(column.at("typestr"));
@@ -203,12 +201,12 @@ class ArrayInterfaceHandler {
}
if (j_shape.size() == 1) {
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1};
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))), 1};
} else {
CHECK_EQ(j_shape.size(), 2)
<< "Only 1D or 2-D arrays currently supported.";
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))),
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
return {static_cast<size_t>(get<Integer const>(j_shape.at(0))),
static_cast<size_t>(get<Integer const>(j_shape.at(1)))};
}
}
template <typename T>
@@ -221,6 +219,7 @@ class ArrayInterfaceHandler {
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
<< "Input data type and typestr mismatch. typestr: " << typestr;
auto shape = ExtractShape(column);
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
@@ -232,8 +231,8 @@ class ArrayInterfaceHandler {
class ArrayInterface {
public:
ArrayInterface() = default;
void Initialize(std::map<std::string, Json> const &column,
bool allow_mask = true) {
explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
CHECK(data) << "Column is null";
@@ -264,25 +263,6 @@ class ArrayInterface {
this->CheckType();
}
explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
auto jinterface = Json::Load({str.c_str(), str.size()});
if (IsA<Object>(jinterface)) {
this->Initialize(get<Object const>(jinterface), allow_mask);
return;
}
if (IsA<Array>(jinterface)) {
CHECK_EQ(get<Array const>(jinterface).size(), 1)
<< "Column: " << ArrayInterfaceErrors::Dimension(1);
this->Initialize(get<Object const>(get<Array const>(jinterface)[0]), allow_mask);
return;
}
}
explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
this->Initialize(column, allow_mask);
}
void CheckType() const {
if (type[1] == 'f' && type[2] == '4') {
return;
@@ -311,7 +291,6 @@ class ArrayInterface {
}
XGBOOST_DEVICE float GetElement(size_t idx) const {
SPAN_CHECK(idx < num_cols * num_rows);
if (type[1] == 'f' && type[2] == '4') {
return reinterpret_cast<float*>(data)[idx];
} else if (type[1] == 'f' && type[2] == '8') {
@@ -339,8 +318,8 @@ class ArrayInterface {
}
RBitField8 valid;
bst_row_t num_rows;
bst_feature_t num_cols;
int32_t num_rows;
int32_t num_cols;
void* data;
char type[3];
};