Enhance inplace prediction. (#6653)
* Accept array interface for csr and array. * Accept an optional proxy dmatrix for metainfo. This constructs an explicit `_ProxyDMatrix` type in Python. * Remove unused doc. * Add strict output.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* Copyright 2014-2021 by Contributors
|
||||
* \file gbtree.cc
|
||||
* \brief gradient boosted tree implementation.
|
||||
* \author Tianqi Chen
|
||||
@@ -265,15 +265,34 @@ class GBTree : public GradientBooster {
|
||||
bool training,
|
||||
unsigned ntree_limit) override;
|
||||
|
||||
void InplacePredict(dmlc::any const &x, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t layer_begin,
|
||||
unsigned layer_end) const override {
|
||||
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t layer_begin, unsigned layer_end) const override {
|
||||
CHECK(configured_);
|
||||
uint32_t tree_begin, tree_end;
|
||||
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||
this->GetPredictor()->InplacePredict(x, model_, missing, out_preds,
|
||||
tree_begin, tree_end);
|
||||
std::tie(tree_begin, tree_end) =
|
||||
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||
std::vector<Predictor const *> predictors{
|
||||
cpu_predictor_.get(),
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
gpu_predictor_.get()
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
};
|
||||
StringView msg{"Unsupported data type for inplace predict."};
|
||||
if (tparam_.predictor == PredictorType::kAuto) {
|
||||
// Try both predictor implementations
|
||||
for (auto const &p : predictors) {
|
||||
if (p && p->InplacePredict(x, p_m, model_, missing, out_preds,
|
||||
tree_begin, tree_end)) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
LOG(FATAL) << msg;
|
||||
} else {
|
||||
bool success = this->GetPredictor()->InplacePredict(
|
||||
x, p_m, model_, missing, out_preds, tree_begin, tree_end);
|
||||
CHECK(success) << msg;
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInstance(const SparsePage::Inst& inst,
|
||||
|
||||
Reference in New Issue
Block a user