Replace dmlc any with std any. (#8892)
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <thrust/fill.h>
|
||||
#include <thrust/host_vector.h>
|
||||
|
||||
#include <any> // for any, any_cast
|
||||
#include <memory>
|
||||
|
||||
#include "../common/bitfield.h"
|
||||
@@ -741,13 +742,13 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
template <typename Adapter, typename Loader>
|
||||
void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, uint32_t tree_end) const {
|
||||
void DispatchedInplacePredict(std::any const& x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel& model, float missing,
|
||||
PredictionCacheEntry* out_preds, uint32_t tree_begin,
|
||||
uint32_t tree_end) const {
|
||||
uint32_t const output_groups = model.learner_model_param->num_output_group;
|
||||
|
||||
auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
|
||||
auto m = std::any_cast<std::shared_ptr<Adapter>>(x);
|
||||
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
|
||||
<< "Number of columns in data must equal to trained model.";
|
||||
CHECK_EQ(dh::CurrentDevice(), m->DeviceIdx())
|
||||
|
||||
Reference in New Issue
Block a user