Enhance inplace prediction. (#6653)
* Accept array interface for csr and array. * Accept an optional proxy dmatrix for metainfo. This constructs an explicit `_ProxyDMatrix` type in Python. * Remove unused doc. * Add strict output.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright by Contributors 2017-2020
|
||||
* Copyright by Contributors 2017-2021
|
||||
*/
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/any.h>
|
||||
@@ -287,7 +287,7 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
|
||||
template <typename Adapter>
|
||||
void DispatchedInplacePredict(dmlc::any const &x,
|
||||
void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, uint32_t tree_end) const {
|
||||
@@ -295,33 +295,44 @@ class CPUPredictor : public Predictor {
|
||||
auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
|
||||
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
|
||||
<< "Number of columns in data must equal to trained model.";
|
||||
MetaInfo info;
|
||||
info.num_col_ = m->NumColumns();
|
||||
info.num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
std::vector<Entry> workspace(info.num_col_ * 8 * threads);
|
||||
if (p_m) {
|
||||
p_m->Info().num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model);
|
||||
} else {
|
||||
MetaInfo info;
|
||||
info.num_row_ = m->NumRows();
|
||||
this->InitOutPredictions(info, &(out_preds->predictions), model);
|
||||
}
|
||||
std::vector<Entry> workspace(m->NumColumns() * 8 * threads);
|
||||
auto &predictions = out_preds->predictions.HostVector();
|
||||
std::vector<RegTree::FVec> thread_temp;
|
||||
InitThreadTemp(threads*kBlockOfRowsSize, model.learner_model_param->num_feature,
|
||||
&thread_temp);
|
||||
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>,
|
||||
kBlockOfRowsSize>(AdapterView<Adapter>(
|
||||
m.get(), missing, common::Span<Entry>{workspace}),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp);
|
||||
InitThreadTemp(threads * kBlockOfRowsSize,
|
||||
model.learner_model_param->num_feature, &thread_temp);
|
||||
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockOfRowsSize>(
|
||||
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}),
|
||||
&predictions, model, tree_begin, tree_end, &thread_temp);
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin, unsigned tree_end) const override {
|
||||
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds, uint32_t tree_begin,
|
||||
unsigned tree_end) const override {
|
||||
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::DenseAdapter>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CSRAdapter>(
|
||||
x, model, missing, out_preds, tree_begin, tree_end);
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::ArrayAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::ArrayAdapter> (
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else if (x.type() == typeid(std::shared_ptr<data::CSRArrayAdapter>)) {
|
||||
this->DispatchedInplacePredict<data::CSRArrayAdapter> (
|
||||
x, p_m, model, missing, out_preds, tree_begin, tree_end);
|
||||
} else {
|
||||
LOG(FATAL) << "Data type is not supported by CPU Predictor.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
|
||||
Reference in New Issue
Block a user