Support host data in proxy DMatrix. (#7087)
This commit is contained in:
parent
5d7cdf2e36
commit
84d359efb8
@ -257,7 +257,10 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
|
|||||||
Line const GetLine(size_t idx) const {
|
Line const GetLine(size_t idx) const {
|
||||||
return Line{array_interface_, idx};
|
return Line{array_interface_, idx};
|
||||||
}
|
}
|
||||||
size_t Size() const { return array_interface_.num_rows; }
|
|
||||||
|
size_t NumRows() const { return array_interface_.num_rows; }
|
||||||
|
size_t NumCols() const { return array_interface_.num_cols; }
|
||||||
|
size_t Size() const { return this->NumRows(); }
|
||||||
|
|
||||||
explicit ArrayAdapterBatch(ArrayInterface array_interface)
|
explicit ArrayAdapterBatch(ArrayInterface array_interface)
|
||||||
: array_interface_{std::move(array_interface)} {}
|
: array_interface_{std::move(array_interface)} {}
|
||||||
@ -288,6 +291,7 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
|||||||
ArrayInterface indptr_;
|
ArrayInterface indptr_;
|
||||||
ArrayInterface indices_;
|
ArrayInterface indices_;
|
||||||
ArrayInterface values_;
|
ArrayInterface values_;
|
||||||
|
bst_feature_t n_features_;
|
||||||
|
|
||||||
class Line {
|
class Line {
|
||||||
ArrayInterface indices_;
|
ArrayInterface indices_;
|
||||||
@ -311,23 +315,27 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
public:
|
||||||
|
static constexpr bool kIsRowMajor = true;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
CSRArrayAdapterBatch() = default;
|
CSRArrayAdapterBatch() = default;
|
||||||
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
|
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
|
||||||
ArrayInterface values)
|
ArrayInterface values, bst_feature_t n_features)
|
||||||
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
||||||
values_{std::move(values)} {
|
values_{std::move(values)}, n_features_{n_features} {
|
||||||
indptr_.AsColumnVector();
|
indptr_.AsColumnVector();
|
||||||
values_.AsColumnVector();
|
values_.AsColumnVector();
|
||||||
indices_.AsColumnVector();
|
indices_.AsColumnVector();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t Size() const {
|
size_t NumRows() const {
|
||||||
size_t size = indptr_.num_rows * indptr_.num_cols;
|
size_t size = indptr_.num_rows * indptr_.num_cols;
|
||||||
size = size == 0 ? 0 : size - 1;
|
size = size == 0 ? 0 : size - 1;
|
||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
static constexpr bool kIsRowMajor = true;
|
size_t NumCols() const { return n_features_; }
|
||||||
|
size_t Size() const { return this->NumRows(); }
|
||||||
|
|
||||||
Line const GetLine(size_t idx) const {
|
Line const GetLine(size_t idx) const {
|
||||||
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
|
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
|
||||||
@ -356,7 +364,8 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch>
|
|||||||
CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
|
CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
|
||||||
size_t num_cols)
|
size_t num_cols)
|
||||||
: indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
|
: indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
|
||||||
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_};
|
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_,
|
||||||
|
static_cast<bst_feature_t>(num_cols_)};
|
||||||
}
|
}
|
||||||
|
|
||||||
CSRArrayAdapterBatch const& Value() const override {
|
CSRArrayAdapterBatch const& Value() const override {
|
||||||
|
|||||||
@ -11,6 +11,7 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
@ -416,5 +417,24 @@ class ArrayInterface {
|
|||||||
Type type;
|
Type type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T> std::string MakeArrayInterface(T const *data, size_t n) {
|
||||||
|
Json arr{Object{}};
|
||||||
|
arr["data"] = Array(std::vector<Json>{
|
||||||
|
Json{Integer{reinterpret_cast<int64_t>(data)}}, Json{Boolean{false}}});
|
||||||
|
arr["shape"] = Array{std::vector<Json>{Json{Integer{n}}, Json{Integer{1}}}};
|
||||||
|
std::string typestr;
|
||||||
|
if (DMLC_LITTLE_ENDIAN) {
|
||||||
|
typestr.push_back('<');
|
||||||
|
} else {
|
||||||
|
typestr.push_back('>');
|
||||||
|
}
|
||||||
|
typestr.push_back(ArrayInterfaceHandler::TypeChar<T>());
|
||||||
|
typestr += std::to_string(sizeof(T));
|
||||||
|
arr["typestr"] = typestr;
|
||||||
|
arr["version"] = 3;
|
||||||
|
std::string str;
|
||||||
|
Json::Dump(arr, &str);
|
||||||
|
return str;
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_DATA_ARRAY_INTERFACE_H_
|
#endif // XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||||
|
|||||||
@ -11,35 +11,16 @@
|
|||||||
#include "sparse_page_source.h"
|
#include "sparse_page_source.h"
|
||||||
#include "ellpack_page.cuh"
|
#include "ellpack_page.cuh"
|
||||||
#include "proxy_dmatrix.h"
|
#include "proxy_dmatrix.h"
|
||||||
|
#include "proxy_dmatrix.cuh"
|
||||||
#include "device_adapter.cuh"
|
#include "device_adapter.cuh"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
template <typename Fn>
|
|
||||||
decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) {
|
|
||||||
if (proxy->Adapter().type() == typeid(std::shared_ptr<CupyAdapter>)) {
|
|
||||||
auto value = dmlc::get<std::shared_ptr<CupyAdapter>>(
|
|
||||||
proxy->Adapter())->Value();
|
|
||||||
return fn(value);
|
|
||||||
} else if (proxy->Adapter().type() == typeid(std::shared_ptr<CudfAdapter>)) {
|
|
||||||
auto value = dmlc::get<std::shared_ptr<CudfAdapter>>(
|
|
||||||
proxy->Adapter())->Value();
|
|
||||||
return fn(value);
|
|
||||||
} else {
|
|
||||||
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
|
|
||||||
auto value = dmlc::get<std::shared_ptr<CudfAdapter>>(
|
|
||||||
proxy->Adapter())->Value();
|
|
||||||
return fn(value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missing, int nthread) {
|
void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missing, int nthread) {
|
||||||
// A handle passed to external iterator.
|
// A handle passed to external iterator.
|
||||||
auto handle = static_cast<std::shared_ptr<DMatrix>*>(proxy_);
|
DMatrixProxy* proxy = MakeProxy(proxy_);
|
||||||
CHECK(handle);
|
|
||||||
DMatrixProxy* proxy = static_cast<DMatrixProxy*>(handle->get());
|
|
||||||
CHECK(proxy);
|
CHECK(proxy);
|
||||||
|
|
||||||
// The external iterator
|
// The external iterator
|
||||||
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
|
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
|
||||||
iter_handle, reset_, next_};
|
iter_handle, reset_, next_};
|
||||||
|
|||||||
29
src/data/proxy_dmatrix.cc
Normal file
29
src/data/proxy_dmatrix.cc
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2021 by Contributors
|
||||||
|
* \file proxy_dmatrix.cc
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "proxy_dmatrix.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace data {
|
||||||
|
void DMatrixProxy::SetArrayData(char const *c_interface) {
|
||||||
|
std::shared_ptr<ArrayAdapter> adapter{
|
||||||
|
new ArrayAdapter(StringView{c_interface})};
|
||||||
|
this->batch_ = adapter;
|
||||||
|
this->Info().num_col_ = adapter->NumColumns();
|
||||||
|
this->Info().num_row_ = adapter->NumRows();
|
||||||
|
}
|
||||||
|
|
||||||
|
void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices,
|
||||||
|
char const *c_values, bst_feature_t n_features, bool on_host) {
|
||||||
|
CHECK(on_host) << "Not implemented on device.";
|
||||||
|
std::shared_ptr<CSRArrayAdapter> adapter{
|
||||||
|
new CSRArrayAdapter(StringView{c_indptr}, StringView{c_indices},
|
||||||
|
StringView{c_values}, n_features)};
|
||||||
|
this->batch_ = adapter;
|
||||||
|
this->Info().num_col_ = adapter->NumColumns();
|
||||||
|
this->Info().num_row_ = adapter->NumRows();
|
||||||
|
}
|
||||||
|
} // namespace data
|
||||||
|
} // namespace xgboost
|
||||||
27
src/data/proxy_dmatrix.cuh
Normal file
27
src/data/proxy_dmatrix.cuh
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2021 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include "device_adapter.cuh"
|
||||||
|
#include "proxy_dmatrix.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace data {
|
||||||
|
template <typename Fn>
|
||||||
|
decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) {
|
||||||
|
if (proxy->Adapter().type() == typeid(std::shared_ptr<CupyAdapter>)) {
|
||||||
|
auto value = dmlc::get<std::shared_ptr<CupyAdapter>>(
|
||||||
|
proxy->Adapter())->Value();
|
||||||
|
return fn(value);
|
||||||
|
} else if (proxy->Adapter().type() == typeid(std::shared_ptr<CudfAdapter>)) {
|
||||||
|
auto value = dmlc::get<std::shared_ptr<CudfAdapter>>(
|
||||||
|
proxy->Adapter())->Value();
|
||||||
|
return fn(value);
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
|
||||||
|
auto value = dmlc::get<std::shared_ptr<CudfAdapter>>(
|
||||||
|
proxy->Adapter())->Value();
|
||||||
|
return fn(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace data
|
||||||
|
} // namespace xgboost
|
||||||
@ -72,6 +72,11 @@ class DMatrixProxy : public DMatrix {
|
|||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetArrayData(char const* c_interface);
|
||||||
|
void SetCSRData(char const *c_indptr, char const *c_indices,
|
||||||
|
char const *c_values, bst_feature_t n_features,
|
||||||
|
bool on_host);
|
||||||
|
|
||||||
MetaInfo& Info() override { return info_; }
|
MetaInfo& Info() override { return info_; }
|
||||||
MetaInfo const& Info() const override { return info_; }
|
MetaInfo const& Info() const override { return info_; }
|
||||||
bool SingleColBlock() const override { return true; }
|
bool SingleColBlock() const override { return true; }
|
||||||
@ -106,6 +111,41 @@ class DMatrixProxy : public DMatrix {
|
|||||||
return batch_;
|
return batch_;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
inline DMatrixProxy *MakeProxy(DMatrixHandle proxy) {
|
||||||
|
auto proxy_handle = static_cast<std::shared_ptr<DMatrix> *>(proxy);
|
||||||
|
CHECK(proxy_handle) << "Invalid proxy handle.";
|
||||||
|
DMatrixProxy *typed = static_cast<DMatrixProxy *>(proxy_handle->get());
|
||||||
|
return typed;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Fn>
|
||||||
|
decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) {
|
||||||
|
if (proxy->Adapter().type() == typeid(std::shared_ptr<CSRArrayAdapter>)) {
|
||||||
|
auto value =
|
||||||
|
dmlc::get<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
|
||||||
|
if (type_error) {
|
||||||
|
*type_error = false;
|
||||||
|
}
|
||||||
|
return fn(value);
|
||||||
|
} else if (proxy->Adapter().type() == typeid(std::shared_ptr<ArrayAdapter>)) {
|
||||||
|
auto value = dmlc::get<std::shared_ptr<ArrayAdapter>>(
|
||||||
|
proxy->Adapter())->Value();
|
||||||
|
if (type_error) {
|
||||||
|
*type_error = false;
|
||||||
|
}
|
||||||
|
return fn(value);
|
||||||
|
} else {
|
||||||
|
if (type_error) {
|
||||||
|
*type_error = true;
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
|
||||||
|
}
|
||||||
|
auto value = dmlc::get<std::shared_ptr<ArrayAdapter>>(
|
||||||
|
proxy->Adapter())->Value();
|
||||||
|
return fn(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_DATA_PROXY_DMATRIX_H_
|
#endif // XGBOOST_DATA_PROXY_DMATRIX_H_
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
// Copyright (c) 2019 by Contributors
|
// Copyright (c) 2019-2021 by XGBoost Contributors
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@ -35,6 +35,27 @@ TEST(Adapter, CSRAdapter) {
|
|||||||
EXPECT_EQ(line2.GetElement(0).column_idx, 1);
|
EXPECT_EQ(line2.GetElement(0).column_idx, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Adapter, CSRArrayAdapter) {
|
||||||
|
HostDeviceVector<bst_row_t> indptr;
|
||||||
|
HostDeviceVector<float> values;
|
||||||
|
HostDeviceVector<bst_feature_t> indices;
|
||||||
|
size_t n_features = 100, n_samples = 10;
|
||||||
|
RandomDataGenerator{n_samples, n_features, 0.5}.GenerateCSR(&values, &indptr, &indices);
|
||||||
|
auto indptr_arr = MakeArrayInterface(indptr.HostPointer(), indptr.Size());
|
||||||
|
auto values_arr = MakeArrayInterface(values.HostPointer(), values.Size());
|
||||||
|
auto indices_arr = MakeArrayInterface(indices.HostPointer(), indices.Size());
|
||||||
|
auto adapter = data::CSRArrayAdapter(
|
||||||
|
StringView{indptr_arr.c_str(), indptr_arr.size()},
|
||||||
|
StringView{values_arr.c_str(), values_arr.size()},
|
||||||
|
StringView{indices_arr.c_str(), indices_arr.size()}, n_features);
|
||||||
|
auto batch = adapter.Value();
|
||||||
|
ASSERT_EQ(batch.NumRows(), n_samples);
|
||||||
|
ASSERT_EQ(batch.NumCols(), n_features);
|
||||||
|
|
||||||
|
ASSERT_EQ(adapter.NumRows(), n_samples);
|
||||||
|
ASSERT_EQ(adapter.NumColumns(), n_features);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(Adapter, CSCAdapterColsMoreThanRows) {
|
TEST(Adapter, CSCAdapterColsMoreThanRows) {
|
||||||
std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8};
|
std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||||
std::vector<unsigned> row_idx = {0, 1, 0, 1, 0, 1, 0, 1};
|
std::vector<unsigned> row_idx = {0, 1, 0, 1, 0, 1, 0, 1};
|
||||||
|
|||||||
31
tests/cpp/data/test_proxy_dmatrix.cc
Normal file
31
tests/cpp/data/test_proxy_dmatrix.cc
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2021 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include "../helpers.h"
|
||||||
|
#include "../../../src/data/proxy_dmatrix.h"
|
||||||
|
#include "../../../src/data/adapter.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace data {
|
||||||
|
TEST(ProxyDMatrix, HostData) {
|
||||||
|
DMatrixProxy proxy;
|
||||||
|
size_t constexpr kRows = 100, kCols = 10;
|
||||||
|
std::vector<HostDeviceVector<float>> label_storage(1);
|
||||||
|
|
||||||
|
HostDeviceVector<float> storage;
|
||||||
|
auto data = RandomDataGenerator(kRows, kCols, 0.5)
|
||||||
|
.Device(0)
|
||||||
|
.GenerateArrayInterface(&storage);
|
||||||
|
|
||||||
|
proxy.SetArrayData(data.c_str());
|
||||||
|
|
||||||
|
auto n_samples = HostAdapterDispatch(
|
||||||
|
&proxy, [](auto const &value) { return value.Size(); });
|
||||||
|
ASSERT_EQ(n_samples, kRows);
|
||||||
|
auto n_features = HostAdapterDispatch(
|
||||||
|
&proxy, [](auto const &value) { return value.NumCols(); });
|
||||||
|
ASSERT_EQ(n_features, kCols);
|
||||||
|
}
|
||||||
|
} // namespace data
|
||||||
|
} // namespace xgboost
|
||||||
@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
TEST(ProxyDMatrix, Basic) {
|
TEST(ProxyDMatrix, DeviceData) {
|
||||||
constexpr size_t kRows{100}, kCols{100};
|
constexpr size_t kRows{100}, kCols{100};
|
||||||
HostDeviceVector<float> storage;
|
HostDeviceVector<float> storage;
|
||||||
auto data = RandomDataGenerator(kRows, kCols, 0.5)
|
auto data = RandomDataGenerator(kRows, kCols, 0.5)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user