Simplify inplace-predict. (#7910)
Pass the `X` as part of Proxy DMatrix instead of an independent `dmlc::any`.
This commit is contained in:
@@ -1,27 +1,27 @@
|
||||
/*!
|
||||
* Copyright by Contributors 2017-2021
|
||||
*/
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/any.h>
|
||||
#include <dmlc/omp.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <mutex>
|
||||
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
#include "../gbm/gbtree_model.h"
|
||||
#include "predict_fn.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/predictor.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "predict_fn.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../gbm/gbtree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
@@ -327,22 +327,24 @@ class CPUPredictor : public Predictor {
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp, n_threads);
|
||||
}
|
||||
|
||||
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
bool InplacePredict(std::shared_ptr<DMatrix> p_m, const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds, uint32_t tree_begin,
|
||||
unsigned tree_end) const override {
|
||||
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
||||
CHECK(proxy)<< "Inplace predict accepts only DMatrixProxy as input.";
|
||||
auto x = proxy->Adapter();
|
||||
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::DenseAdapter, kBlockOfRowsSize>(
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CSRAdapter, 1>(
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
this->DispatchedInplacePredict<data::CSRAdapter, 1>(x, p_m, model, missing, out_preds,
|
||||
tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::ArrayAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::ArrayAdapter, kBlockOfRowsSize> (
|
||||
this->DispatchedInplacePredict<data::ArrayAdapter, kBlockOfRowsSize>(
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CSRArrayAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CSRArrayAdapter, 1> (
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
this->DispatchedInplacePredict<data::CSRArrayAdapter, 1>(x, p_m, model, missing, out_preds,
|
||||
tree_begin, tree_end);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user