diff --git a/src/data/adapter.h b/src/data/adapter.h index 924fb9f82..8502ebd34 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -257,7 +257,10 @@ class ArrayAdapterBatch : public detail::NoMetaInfo { Line const GetLine(size_t idx) const { return Line{array_interface_, idx}; } - size_t Size() const { return array_interface_.num_rows; } + + size_t NumRows() const { return array_interface_.num_rows; } + size_t NumCols() const { return array_interface_.num_cols; } + size_t Size() const { return this->NumRows(); } explicit ArrayAdapterBatch(ArrayInterface array_interface) : array_interface_{std::move(array_interface)} {} @@ -288,6 +291,7 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo { ArrayInterface indptr_; ArrayInterface indices_; ArrayInterface values_; + bst_feature_t n_features_; class Line { ArrayInterface indices_; @@ -311,23 +315,27 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo { } }; + public: + static constexpr bool kIsRowMajor = true; + public: CSRArrayAdapterBatch() = default; CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices, - ArrayInterface values) + ArrayInterface values, bst_feature_t n_features) : indptr_{std::move(indptr)}, indices_{std::move(indices)}, - values_{std::move(values)} { + values_{std::move(values)}, n_features_{n_features} { indptr_.AsColumnVector(); values_.AsColumnVector(); indices_.AsColumnVector(); } - size_t Size() const { + size_t NumRows() const { size_t size = indptr_.num_rows * indptr_.num_cols; size = size == 0 ? 0 : size - 1; return size; } - static constexpr bool kIsRowMajor = true; + size_t NumCols() const { return n_features_; } + size_t Size() const { return this->NumRows(); } Line const GetLine(size_t idx) const { auto begin_offset = indptr_.GetElement(idx, 0); @@ -356,7 +364,8 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter CSRArrayAdapter(StringView indptr, StringView indices, StringView values, size_t num_cols) : indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} { - batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_}; + batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_, + static_cast(num_cols_)}; } CSRArrayAdapterBatch const& Value() const override { diff --git a/src/data/array_interface.h b/src/data/array_interface.h index fd79b1348..b7ca31143 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "xgboost/base.h" #include "xgboost/data.h" @@ -416,5 +417,24 @@ class ArrayInterface { Type type; }; +template std::string MakeArrayInterface(T const *data, size_t n) { + Json arr{Object{}}; + arr["data"] = Array(std::vector{ + Json{Integer{reinterpret_cast(data)}}, Json{Boolean{false}}}); + arr["shape"] = Array{std::vector{Json{Integer{n}}, Json{Integer{1}}}}; + std::string typestr; + if (DMLC_LITTLE_ENDIAN) { + typestr.push_back('<'); + } else { + typestr.push_back('>'); + } + typestr.push_back(ArrayInterfaceHandler::TypeChar()); + typestr += std::to_string(sizeof(T)); + arr["typestr"] = typestr; + arr["version"] = 3; + std::string str; + Json::Dump(arr, &str); + return str; +} } // namespace xgboost #endif // XGBOOST_DATA_ARRAY_INTERFACE_H_ diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index 87fd4af93..4b584a4e7 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -11,35 +11,16 @@ #include "sparse_page_source.h" #include "ellpack_page.cuh" #include "proxy_dmatrix.h" +#include "proxy_dmatrix.cuh" #include "device_adapter.cuh" namespace xgboost { namespace data { - -template -decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) { - if (proxy->Adapter().type() == typeid(std::shared_ptr)) { - auto value = dmlc::get>( - proxy->Adapter())->Value(); - return fn(value); - } else if (proxy->Adapter().type() == typeid(std::shared_ptr)) { - auto value = dmlc::get>( - proxy->Adapter())->Value(); - return fn(value); - } else { - LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name(); - auto value = dmlc::get>( - proxy->Adapter())->Value(); - return fn(value); - } -} - void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missing, int nthread) { // A handle passed to external iterator. - auto handle = static_cast*>(proxy_); - CHECK(handle); - DMatrixProxy* proxy = static_cast(handle->get()); + DMatrixProxy* proxy = MakeProxy(proxy_); CHECK(proxy); + // The external iterator auto iter = DataIterProxy{ iter_handle, reset_, next_}; diff --git a/src/data/proxy_dmatrix.cc b/src/data/proxy_dmatrix.cc new file mode 100644 index 000000000..0c60891a3 --- /dev/null +++ b/src/data/proxy_dmatrix.cc @@ -0,0 +1,29 @@ +/*! + * Copyright 2021 by Contributors + * \file proxy_dmatrix.cc + */ + +#include "proxy_dmatrix.h" + +namespace xgboost { +namespace data { +void DMatrixProxy::SetArrayData(char const *c_interface) { + std::shared_ptr adapter{ + new ArrayAdapter(StringView{c_interface})}; + this->batch_ = adapter; + this->Info().num_col_ = adapter->NumColumns(); + this->Info().num_row_ = adapter->NumRows(); +} + +void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices, + char const *c_values, bst_feature_t n_features, bool on_host) { + CHECK(on_host) << "Not implemented on device."; + std::shared_ptr adapter{ + new CSRArrayAdapter(StringView{c_indptr}, StringView{c_indices}, + StringView{c_values}, n_features)}; + this->batch_ = adapter; + this->Info().num_col_ = adapter->NumColumns(); + this->Info().num_row_ = adapter->NumRows(); +} +} // namespace data +} // namespace xgboost diff --git a/src/data/proxy_dmatrix.cuh b/src/data/proxy_dmatrix.cuh new file mode 100644 index 000000000..38cbffe50 --- /dev/null +++ b/src/data/proxy_dmatrix.cuh @@ -0,0 +1,27 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#include "device_adapter.cuh" +#include "proxy_dmatrix.h" + +namespace xgboost { +namespace data { +template +decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) { + if (proxy->Adapter().type() == typeid(std::shared_ptr)) { + auto value = dmlc::get>( + proxy->Adapter())->Value(); + return fn(value); + } else if (proxy->Adapter().type() == typeid(std::shared_ptr)) { + auto value = dmlc::get>( + proxy->Adapter())->Value(); + return fn(value); + } else { + LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name(); + auto value = dmlc::get>( + proxy->Adapter())->Value(); + return fn(value); + } +} +} // namespace data +} // namespace xgboost diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index cb5cc6e02..4baeec0b3 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -72,6 +72,11 @@ class DMatrixProxy : public DMatrix { #endif // defined(XGBOOST_USE_CUDA) } + void SetArrayData(char const* c_interface); + void SetCSRData(char const *c_indptr, char const *c_indices, + char const *c_values, bst_feature_t n_features, + bool on_host); + MetaInfo& Info() override { return info_; } MetaInfo const& Info() const override { return info_; } bool SingleColBlock() const override { return true; } @@ -106,6 +111,41 @@ class DMatrixProxy : public DMatrix { return batch_; } }; + +inline DMatrixProxy *MakeProxy(DMatrixHandle proxy) { + auto proxy_handle = static_cast *>(proxy); + CHECK(proxy_handle) << "Invalid proxy handle."; + DMatrixProxy *typed = static_cast(proxy_handle->get()); + return typed; +} + +template +decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) { + if (proxy->Adapter().type() == typeid(std::shared_ptr)) { + auto value = + dmlc::get>(proxy->Adapter())->Value(); + if (type_error) { + *type_error = false; + } + return fn(value); + } else if (proxy->Adapter().type() == typeid(std::shared_ptr)) { + auto value = dmlc::get>( + proxy->Adapter())->Value(); + if (type_error) { + *type_error = false; + } + return fn(value); + } else { + if (type_error) { + *type_error = true; + } else { + LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name(); + } + auto value = dmlc::get>( + proxy->Adapter())->Value(); + return fn(value); + } +} } // namespace data } // namespace xgboost #endif // XGBOOST_DATA_PROXY_DMATRIX_H_ diff --git a/tests/cpp/data/test_adapter.cc b/tests/cpp/data/test_adapter.cc index c833b7503..ccb19de71 100644 --- a/tests/cpp/data/test_adapter.cc +++ b/tests/cpp/data/test_adapter.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2019 by Contributors +// Copyright (c) 2019-2021 by XGBoost Contributors #include #include #include @@ -35,6 +35,27 @@ TEST(Adapter, CSRAdapter) { EXPECT_EQ(line2.GetElement(0).column_idx, 1); } +TEST(Adapter, CSRArrayAdapter) { + HostDeviceVector indptr; + HostDeviceVector values; + HostDeviceVector indices; + size_t n_features = 100, n_samples = 10; + RandomDataGenerator{n_samples, n_features, 0.5}.GenerateCSR(&values, &indptr, &indices); + auto indptr_arr = MakeArrayInterface(indptr.HostPointer(), indptr.Size()); + auto values_arr = MakeArrayInterface(values.HostPointer(), values.Size()); + auto indices_arr = MakeArrayInterface(indices.HostPointer(), indices.Size()); + auto adapter = data::CSRArrayAdapter( + StringView{indptr_arr.c_str(), indptr_arr.size()}, + StringView{values_arr.c_str(), values_arr.size()}, + StringView{indices_arr.c_str(), indices_arr.size()}, n_features); + auto batch = adapter.Value(); + ASSERT_EQ(batch.NumRows(), n_samples); + ASSERT_EQ(batch.NumCols(), n_features); + + ASSERT_EQ(adapter.NumRows(), n_samples); + ASSERT_EQ(adapter.NumColumns(), n_features); +} + TEST(Adapter, CSCAdapterColsMoreThanRows) { std::vector data = {1, 2, 3, 4, 5, 6, 7, 8}; std::vector row_idx = {0, 1, 0, 1, 0, 1, 0, 1}; diff --git a/tests/cpp/data/test_proxy_dmatrix.cc b/tests/cpp/data/test_proxy_dmatrix.cc new file mode 100644 index 000000000..a6d0b2188 --- /dev/null +++ b/tests/cpp/data/test_proxy_dmatrix.cc @@ -0,0 +1,31 @@ +/*! + * Copyright 2021 XGBoost contributors + */ +#include +#include "../helpers.h" +#include "../../../src/data/proxy_dmatrix.h" +#include "../../../src/data/adapter.h" + +namespace xgboost { +namespace data { +TEST(ProxyDMatrix, HostData) { + DMatrixProxy proxy; + size_t constexpr kRows = 100, kCols = 10; + std::vector> label_storage(1); + + HostDeviceVector storage; + auto data = RandomDataGenerator(kRows, kCols, 0.5) + .Device(0) + .GenerateArrayInterface(&storage); + + proxy.SetArrayData(data.c_str()); + + auto n_samples = HostAdapterDispatch( + &proxy, [](auto const &value) { return value.Size(); }); + ASSERT_EQ(n_samples, kRows); + auto n_features = HostAdapterDispatch( + &proxy, [](auto const &value) { return value.NumCols(); }); + ASSERT_EQ(n_features, kCols); +} +} // namespace data +} // namespace xgboost diff --git a/tests/cpp/data/test_proxy_dmatrix.cu b/tests/cpp/data/test_proxy_dmatrix.cu index 5460995d9..19aa7a3ee 100644 --- a/tests/cpp/data/test_proxy_dmatrix.cu +++ b/tests/cpp/data/test_proxy_dmatrix.cu @@ -7,7 +7,7 @@ namespace xgboost { namespace data { -TEST(ProxyDMatrix, Basic) { +TEST(ProxyDMatrix, DeviceData) { constexpr size_t kRows{100}, kCols{100}; HostDeviceVector storage; auto data = RandomDataGenerator(kRows, kCols, 0.5)