Support host data in proxy DMatrix. (#7087)

This commit is contained in:
Jiaming Yuan
2021-07-08 11:35:48 +08:00
committed by GitHub
parent 5d7cdf2e36
commit 84d359efb8
9 changed files with 188 additions and 30 deletions

View File

@@ -257,7 +257,10 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
/*! \brief Build a Line for row \p idx on top of the wrapped array interface. */
Line const GetLine(size_t idx) const
{
  return Line{array_interface_, idx};
}
/*! \brief Size of the batch, taken from the row count of the array interface. */
size_t Size() const {
  return array_interface_.num_rows;
}
/*! \brief Row count reported by the wrapped array interface. */
size_t NumRows() const {
  return array_interface_.num_rows;
}
/*! \brief Column count reported by the wrapped array interface. */
size_t NumCols() const {
  return array_interface_.num_cols;
}
size_t Size() const { return this->NumRows(); }
/*! \brief Construct the batch by moving \p array_interface into it. */
explicit ArrayAdapterBatch(ArrayInterface array_interface)
    : array_interface_{std::move(array_interface)}
{
}
@@ -288,6 +291,7 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
ArrayInterface indptr_;
ArrayInterface indices_;
ArrayInterface values_;
// Number of features (columns) reported by NumCols().  Zero-initialized so
// that a default-constructed batch never exposes an indeterminate value.
bst_feature_t n_features_{0};
class Line {
ArrayInterface indices_;
@@ -311,23 +315,27 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
}
};
public:
static constexpr bool kIsRowMajor = true;
public:
/*! \brief Default constructor; every member is default-initialized. */
CSRArrayAdapterBatch() = default;
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
ArrayInterface values)
ArrayInterface values, bst_feature_t n_features)
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
values_{std::move(values)} {
values_{std::move(values)}, n_features_{n_features} {
indptr_.AsColumnVector();
values_.AsColumnVector();
indices_.AsColumnVector();
}
size_t Size() const {
/*!
 * \brief Number of rows described by the index pointer array.
 *
 * The index pointer holds one more entry than there are rows, so the row
 * count is its element count minus one; an empty index pointer yields zero.
 */
size_t NumRows() const {
  size_t const n_offsets = indptr_.num_rows * indptr_.num_cols;
  if (n_offsets == 0) {
    return 0;
  }
  return n_offsets - 1;
}
static constexpr bool kIsRowMajor = true;
/*! \brief Number of columns, i.e. the stored feature count. */
size_t NumCols() const {
  return n_features_;
}
size_t Size() const { return this->NumRows(); }
Line const GetLine(size_t idx) const {
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
@@ -356,7 +364,8 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch>
CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
size_t num_cols)
: indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_};
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_,
static_cast<bst_feature_t>(num_cols_)};
}
CSRArrayAdapterBatch const& Value() const override {

View File

@@ -11,6 +11,7 @@
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "xgboost/base.h"
#include "xgboost/data.h"
@@ -416,5 +417,24 @@ class ArrayInterface {
Type type;
};
template <typename T> std::string MakeArrayInterface(T const *data, size_t n) {
Json arr{Object{}};
arr["data"] = Array(std::vector<Json>{
Json{Integer{reinterpret_cast<int64_t>(data)}}, Json{Boolean{false}}});
arr["shape"] = Array{std::vector<Json>{Json{Integer{n}}, Json{Integer{1}}}};
std::string typestr;
if (DMLC_LITTLE_ENDIAN) {
typestr.push_back('<');
} else {
typestr.push_back('>');
}
typestr.push_back(ArrayInterfaceHandler::TypeChar<T>());
typestr += std::to_string(sizeof(T));
arr["typestr"] = typestr;
arr["version"] = 3;
std::string str;
Json::Dump(arr, &str);
return str;
}
} // namespace xgboost
#endif // XGBOOST_DATA_ARRAY_INTERFACE_H_

View File

@@ -11,35 +11,16 @@
#include "sparse_page_source.h"
#include "ellpack_page.cuh"
#include "proxy_dmatrix.h"
#include "proxy_dmatrix.cuh"
#include "device_adapter.cuh"
namespace xgboost {
namespace data {
/*!
 * \brief Call \p fn with the batch of the device adapter stored in \p proxy.
 *
 * The adapter held by the proxy is matched against CupyAdapter and
 * CudfAdapter, in that order.  Any other stored type is reported through
 * LOG(FATAL).
 *
 * \param proxy Proxy DMatrix whose adapter is inspected.
 * \param fn    Callable invoked with a copy of the adapter's batch.
 * \return Whatever \p fn returns.
 */
template <typename Fn>
decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) {
  if (proxy->Adapter().type() == typeid(std::shared_ptr<CupyAdapter>)) {
    auto cupy_batch =
        dmlc::get<std::shared_ptr<CupyAdapter>>(proxy->Adapter())->Value();
    return fn(cupy_batch);
  }

  if (proxy->Adapter().type() == typeid(std::shared_ptr<CudfAdapter>)) {
    auto cudf_batch =
        dmlc::get<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
    return fn(cudf_batch);
  }

  LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
  // NOTE(review): presumably never reached because LOG(FATAL) does not
  // return; the statements below give this path a return statement.
  auto fallback_batch =
      dmlc::get<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
  return fn(fallback_batch);
}
void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missing, int nthread) {
// A handle passed to external iterator.
auto handle = static_cast<std::shared_ptr<DMatrix>*>(proxy_);
CHECK(handle);
DMatrixProxy* proxy = static_cast<DMatrixProxy*>(handle->get());
DMatrixProxy* proxy = MakeProxy(proxy_);
CHECK(proxy);
// The external iterator
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
iter_handle, reset_, next_};

29
src/data/proxy_dmatrix.cc Normal file
View File

@@ -0,0 +1,29 @@
/*!
* Copyright 2021 by Contributors
* \file proxy_dmatrix.cc
*/
#include "proxy_dmatrix.h"
namespace xgboost {
namespace data {
void DMatrixProxy::SetArrayData(char const *c_interface) {
std::shared_ptr<ArrayAdapter> adapter{
new ArrayAdapter(StringView{c_interface})};
this->batch_ = adapter;
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
}
void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices,
char const *c_values, bst_feature_t n_features, bool on_host) {
CHECK(on_host) << "Not implemented on device.";
std::shared_ptr<CSRArrayAdapter> adapter{
new CSRArrayAdapter(StringView{c_indptr}, StringView{c_indices},
StringView{c_values}, n_features)};
this->batch_ = adapter;
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
}
} // namespace data
} // namespace xgboost

View File

@@ -0,0 +1,27 @@
/*!
* Copyright 2021 XGBoost contributors
*/
#include "device_adapter.cuh"
#include "proxy_dmatrix.h"
namespace xgboost {
namespace data {
/*!
 * \brief Call \p fn with the batch of the device adapter stored in \p proxy.
 *
 * The adapter held by the proxy is matched against CupyAdapter and
 * CudfAdapter, in that order.  Any other stored type is reported through
 * LOG(FATAL).
 *
 * \param proxy Proxy DMatrix whose adapter is inspected.
 * \param fn    Callable invoked with a copy of the adapter's batch.
 * \return Whatever \p fn returns.
 */
template <typename Fn>
decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) {
  if (proxy->Adapter().type() == typeid(std::shared_ptr<CupyAdapter>)) {
    auto cupy_batch =
        dmlc::get<std::shared_ptr<CupyAdapter>>(proxy->Adapter())->Value();
    return fn(cupy_batch);
  }

  if (proxy->Adapter().type() == typeid(std::shared_ptr<CudfAdapter>)) {
    auto cudf_batch =
        dmlc::get<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
    return fn(cudf_batch);
  }

  LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
  // NOTE(review): presumably never reached because LOG(FATAL) does not
  // return; the statements below give this path a return statement.
  auto fallback_batch =
      dmlc::get<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
  return fn(fallback_batch);
}
} // namespace data
} // namespace xgboost

View File

@@ -72,6 +72,11 @@ class DMatrixProxy : public DMatrix {
#endif // defined(XGBOOST_USE_CUDA)
}
void SetArrayData(char const* c_interface);
void SetCSRData(char const *c_indptr, char const *c_indices,
char const *c_values, bst_feature_t n_features,
bool on_host);
MetaInfo& Info() override { return info_; }
MetaInfo const& Info() const override { return info_; }
bool SingleColBlock() const override { return true; }
@@ -106,6 +111,41 @@ class DMatrixProxy : public DMatrix {
return batch_;
}
};
/*!
 * \brief Recover the DMatrixProxy behind an opaque DMatrix handle.
 *
 * The handle is interpreted as a pointer to std::shared_ptr<DMatrix>; a null
 * handle fails the check.  The stored DMatrix pointer is converted with
 * static_cast, so no run-time verification of the dynamic type happens here,
 * and the result is null when the shared pointer is empty.
 *
 * \param proxy Opaque handle.
 * \return Non-owning pointer to the proxy.
 */
inline DMatrixProxy *MakeProxy(DMatrixHandle proxy) {
  auto *shared = static_cast<std::shared_ptr<DMatrix> *>(proxy);
  CHECK(shared) << "Invalid proxy handle.";
  return static_cast<DMatrixProxy *>(shared->get());
}
/*!
 * \brief Call \p fn with the batch of the host adapter stored in \p proxy.
 *
 * The adapter held by the proxy is matched against CSRArrayAdapter and
 * ArrayAdapter, in that order.
 *
 * \param proxy      Proxy DMatrix whose adapter is inspected.
 * \param fn         Callable invoked with a copy of the adapter's batch.  Its
 *                   result type must be the same for both batch types and
 *                   must be void or value-initializable, because a
 *                   value-initialized result is returned for unknown types.
 * \param type_error Optional output flag.  When provided it is set to false
 *                   if the adapter type is recognised and to true otherwise.
 *                   When it is null an unknown type is reported through
 *                   LOG(FATAL).
 * \return The result of \p fn, or a value-initialized result when the adapter
 *         type is unknown and \p type_error is provided.
 */
template <typename Fn>
decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) {
  if (proxy->Adapter().type() == typeid(std::shared_ptr<CSRArrayAdapter>)) {
    auto value =
        dmlc::get<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
    if (type_error) {
      *type_error = false;
    }
    return fn(value);
  } else if (proxy->Adapter().type() == typeid(std::shared_ptr<ArrayAdapter>)) {
    auto value = dmlc::get<std::shared_ptr<ArrayAdapter>>(
        proxy->Adapter())->Value();
    if (type_error) {
      *type_error = false;
    }
    return fn(value);
  } else {
    if (type_error) {
      *type_error = true;
    } else {
      LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
    }
    // The stored adapter is known not to be an ArrayAdapter here, so the
    // checked dmlc::get cast must not be evaluated: doing so would fail before
    // the caller could inspect *type_error.  The expression inside decltype is
    // unevaluated and only names the result type of fn.
    using Result = decltype(fn(
        dmlc::get<std::shared_ptr<ArrayAdapter>>(proxy->Adapter())->Value()));
    return Result();
  }
}
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_PROXY_DMATRIX_H_