From 7c3a168ffd8c519ce21592d2c348be3a89d68913 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 16 Jun 2020 20:02:35 +0800 Subject: [PATCH] Revert "Accept string for ArrayInterface constructor." This reverts commit e8ecafb8dc628f45b75b4c2844a236d27e0a6d98. --- src/data/array_interface.h | 41 +++++---------------- src/data/data.cu | 2 +- tests/cpp/data/test_array_interface.cc | 51 -------------------------- tests/cpp/helpers.cc | 8 +--- 4 files changed, 12 insertions(+), 90 deletions(-) delete mode 100644 tests/cpp/data/test_array_interface.cc diff --git a/src/data/array_interface.h b/src/data/array_interface.h index c4db6c5dc..53152b63d 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -1,7 +1,7 @@ /*! * Copyright 2019 by Contributors * \file array_interface.h - * \brief View of __array_interface__ + * \brief Basic structure holding a reference to arrow columnar data format. */ #ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_ #define XGBOOST_DATA_ARRAY_INTERFACE_H_ @@ -11,7 +11,6 @@ #include #include -#include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/json.h" #include "xgboost/logging.h" @@ -114,7 +113,6 @@ class ArrayInterfaceHandler { get( obj.at("data")) .at(0)))); - CHECK(p_data); return p_data; } @@ -188,7 +186,7 @@ class ArrayInterfaceHandler { return 0; } - static std::pair ExtractShape( + static std::pair ExtractShape( std::map const& column) { auto j_shape = get(column.at("shape")); auto typestr = get(column.at("typestr")); @@ -203,12 +201,12 @@ class ArrayInterfaceHandler { } if (j_shape.size() == 1) { - return {static_cast(get(j_shape.at(0))), 1}; + return {static_cast(get(j_shape.at(0))), 1}; } else { CHECK_EQ(j_shape.size(), 2) << "Only 1D or 2-D arrays currently supported."; - return {static_cast(get(j_shape.at(0))), - static_cast(get(j_shape.at(1)))}; + return {static_cast(get(j_shape.at(0))), + static_cast(get(j_shape.at(1)))}; } } template @@ -221,6 +219,7 @@ class ArrayInterfaceHandler { CHECK_EQ(typestr.at(2), static_cast(sizeof(T) + 48)) << "Input data type and typestr mismatch. typestr: " << typestr; + auto shape = ExtractShape(column); T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData(column); @@ -232,8 +231,8 @@ class ArrayInterfaceHandler { class ArrayInterface { public: ArrayInterface() = default; - void Initialize(std::map const &column, - bool allow_mask = true) { + explicit ArrayInterface(std::map const &column, + bool allow_mask = true) { ArrayInterfaceHandler::Validate(column); data = ArrayInterfaceHandler::GetPtrFromArrayData(column); CHECK(data) << "Column is null"; @@ -264,25 +263,6 @@ class ArrayInterface { this->CheckType(); } - explicit ArrayInterface(std::string const& str, bool allow_mask = true) { - auto jinterface = Json::Load({str.c_str(), str.size()}); - if (IsA(jinterface)) { - this->Initialize(get(jinterface), allow_mask); - return; - } - if (IsA(jinterface)) { - CHECK_EQ(get(jinterface).size(), 1) - << "Column: " << ArrayInterfaceErrors::Dimension(1); - this->Initialize(get(get(jinterface)[0]), allow_mask); - return; - } - } - - explicit ArrayInterface(std::map const &column, - bool allow_mask = true) { - this->Initialize(column, allow_mask); - } - void CheckType() const { if (type[1] == 'f' && type[2] == '4') { return; @@ -311,7 +291,6 @@ class ArrayInterface { } XGBOOST_DEVICE float GetElement(size_t idx) const { - SPAN_CHECK(idx < num_cols * num_rows); if (type[1] == 'f' && type[2] == '4') { return reinterpret_cast(data)[idx]; } else if (type[1] == 'f' && type[2] == '8') { @@ -339,8 +318,8 @@ class ArrayInterface { } RBitField8 valid; - bst_row_t num_rows; - bst_feature_t num_cols; + int32_t num_rows; + int32_t num_cols; void* data; char type[3]; }; diff --git a/src/data/data.cu b/src/data/data.cu index fb57f4751..526f9a673 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -63,7 +63,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { auto const& j_arr = get(j_interface); CHECK_EQ(j_arr.size(), 1) << "MetaInfo: " << c_key << ". " << ArrayInterfaceErrors::Dimension(1); - ArrayInterface array_interface(interface_str); + ArrayInterface array_interface(get(j_arr[0])); std::string key{c_key}; CHECK(!array_interface.valid.Data()) << "Meta info " << key << " should be dense, found validity mask"; diff --git a/tests/cpp/data/test_array_interface.cc b/tests/cpp/data/test_array_interface.cc deleted file mode 100644 index 5fe93ffa4..000000000 --- a/tests/cpp/data/test_array_interface.cc +++ /dev/null @@ -1,51 +0,0 @@ -/*! - * Copyright 2020 by XGBoost Contributors - */ -#include -#include -#include "../helpers.h" -#include "../../../src/data/array_interface.h" - -namespace xgboost { -TEST(ArrayInterface, Initialize) { - size_t constexpr kRows = 10, kCols = 10; - HostDeviceVector storage; - auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); - auto arr_interface = ArrayInterface(array); - ASSERT_EQ(arr_interface.num_rows, kRows); - ASSERT_EQ(arr_interface.num_cols, kCols); - ASSERT_EQ(arr_interface.data, storage.ConstHostPointer()); -} - -TEST(ArrayInterface, Error) { - constexpr size_t kRows = 16, kCols = 10; - Json column { Object() }; - std::vector j_shape {Json(Integer(static_cast(kRows)))}; - column["shape"] = Array(j_shape); - std::vector j_data { - Json(Integer(reinterpret_cast(nullptr))), - Json(Boolean(false))}; - - auto const& column_obj = get(column); - // missing version - EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj), dmlc::Error); - column["version"] = Integer(static_cast(1)); - // missing data - EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj), dmlc::Error); - column["data"] = j_data; - // missing typestr - EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj), dmlc::Error); - column["typestr"] = String("(column_obj), dmlc::Error); - - HostDeviceVector storage; - auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage); - j_data = { - Json(Integer(reinterpret_cast(storage.ConstHostPointer()))), - Json(Boolean(false))}; - column["data"] = j_data; - EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj)); -} - -} // namespace xgboost diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 2274e57e7..893891d13 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -182,13 +182,7 @@ Json RandomDataGenerator::ArrayInterfaceImpl(HostDeviceVector *storage, this->GenerateDense(storage); Json array_interface {Object()}; array_interface["data"] = std::vector(2); - if (storage->DeviceCanRead()) { - array_interface["data"][0] = - Integer(reinterpret_cast(storage->ConstDevicePointer())); - } else { - array_interface["data"][0] = - Integer(reinterpret_cast(storage->ConstHostPointer())); - } + array_interface["data"][0] = Integer(reinterpret_cast(storage->DevicePointer())); array_interface["data"][1] = Boolean(false); array_interface["shape"] = std::vector(2);