initial merge
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
/*!
|
||||
* Copyright 2020-2022, XGBoost contributors
|
||||
/**
|
||||
* Copyright 2020-2023, XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_PROXY_DMATRIX_H_
|
||||
#define XGBOOST_DATA_PROXY_DMATRIX_H_
|
||||
|
||||
#include <dmlc/any.h>
|
||||
|
||||
#include <any> // for any, any_cast
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
@@ -15,8 +14,7 @@
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
namespace xgboost::data {
|
||||
/*
|
||||
* \brief A proxy to external iterator.
|
||||
*/
|
||||
@@ -44,7 +42,7 @@ class DataIterProxy {
|
||||
*/
|
||||
class DMatrixProxy : public DMatrix {
|
||||
MetaInfo info_;
|
||||
dmlc::any batch_;
|
||||
std::any batch_;
|
||||
Context ctx_;
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||
@@ -115,9 +113,7 @@ class DMatrixProxy : public DMatrix {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
return BatchSet<ExtSparsePage>(BatchIterator<ExtSparsePage>(nullptr));
|
||||
}
|
||||
dmlc::any Adapter() const {
|
||||
return batch_;
|
||||
}
|
||||
std::any Adapter() const { return batch_; }
|
||||
};
|
||||
|
||||
inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) {
|
||||
@@ -131,15 +127,13 @@ inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) {
|
||||
template <typename Fn>
|
||||
decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) {
|
||||
if (proxy->Adapter().type() == typeid(std::shared_ptr<CSRArrayAdapter>)) {
|
||||
auto value =
|
||||
dmlc::get<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
|
||||
auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
|
||||
if (type_error) {
|
||||
*type_error = false;
|
||||
}
|
||||
return fn(value);
|
||||
} else if (proxy->Adapter().type() == typeid(std::shared_ptr<ArrayAdapter>)) {
|
||||
auto value = dmlc::get<std::shared_ptr<ArrayAdapter>>(
|
||||
proxy->Adapter())->Value();
|
||||
auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter())->Value();
|
||||
if (type_error) {
|
||||
*type_error = false;
|
||||
}
|
||||
@@ -154,6 +148,5 @@ decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_
|
||||
decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
|
||||
}
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::data
|
||||
#endif // XGBOOST_DATA_PROXY_DMATRIX_H_
|
||||
|
||||
Reference in New Issue
Block a user