Support host data in proxy DMatrix. (#7087)
This commit is contained in:
@@ -72,6 +72,11 @@ class DMatrixProxy : public DMatrix {
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
void SetArrayData(char const* c_interface);
|
||||
void SetCSRData(char const *c_indptr, char const *c_indices,
|
||||
char const *c_values, bst_feature_t n_features,
|
||||
bool on_host);
|
||||
|
||||
MetaInfo& Info() override { return info_; }
|
||||
MetaInfo const& Info() const override { return info_; }
|
||||
bool SingleColBlock() const override { return true; }
|
||||
@@ -106,6 +111,41 @@ class DMatrixProxy : public DMatrix {
|
||||
return batch_;
|
||||
}
|
||||
};
|
||||
|
||||
inline DMatrixProxy *MakeProxy(DMatrixHandle proxy) {
|
||||
auto proxy_handle = static_cast<std::shared_ptr<DMatrix> *>(proxy);
|
||||
CHECK(proxy_handle) << "Invalid proxy handle.";
|
||||
DMatrixProxy *typed = static_cast<DMatrixProxy *>(proxy_handle->get());
|
||||
return typed;
|
||||
}
|
||||
|
||||
template <typename Fn>
|
||||
decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) {
|
||||
if (proxy->Adapter().type() == typeid(std::shared_ptr<CSRArrayAdapter>)) {
|
||||
auto value =
|
||||
dmlc::get<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
|
||||
if (type_error) {
|
||||
*type_error = false;
|
||||
}
|
||||
return fn(value);
|
||||
} else if (proxy->Adapter().type() == typeid(std::shared_ptr<ArrayAdapter>)) {
|
||||
auto value = dmlc::get<std::shared_ptr<ArrayAdapter>>(
|
||||
proxy->Adapter())->Value();
|
||||
if (type_error) {
|
||||
*type_error = false;
|
||||
}
|
||||
return fn(value);
|
||||
} else {
|
||||
if (type_error) {
|
||||
*type_error = true;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
|
||||
}
|
||||
auto value = dmlc::get<std::shared_ptr<ArrayAdapter>>(
|
||||
proxy->Adapter())->Value();
|
||||
return fn(value);
|
||||
}
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_PROXY_DMATRIX_H_
|
||||
|
||||
Reference in New Issue
Block a user