Enhance inplace prediction. (#6653)

* Accept array interface for csr and array.
* Accept an optional proxy dmatrix for metainfo.

This constructs an explicit `_ProxyDMatrix` type in Python.

* Remove unused doc.
* Add strict output.
This commit is contained in:
Jiaming Yuan
2021-02-02 11:41:46 +08:00
committed by GitHub
parent 87ab1ad607
commit 411592a347
22 changed files with 955 additions and 530 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2020 by Contributors
* Copyright 2014-2021 by Contributors
* \file learner.cc
* \brief Implementation of learning algorithm.
* \author Tianqi Chen
@@ -1110,23 +1110,30 @@ class LearnerImpl : public LearnerIO {
CHECK(!this->need_configuration_);
return this->gbm_->BoostedRounds();
}
uint32_t Groups() const override {
CHECK(!this->need_configuration_);
return this->learner_model_param_.num_output_group;
}
XGBAPIThreadLocalEntry& GetThreadLocal() const override {
return (*LearnerAPIThreadLocalStore::Get())[this];
}
void InplacePredict(dmlc::any const &x, std::string const &type,
float missing, HostDeviceVector<bst_float> **out_preds,
uint32_t layer_begin, uint32_t layer_end) override {
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
PredictionType type, float missing,
HostDeviceVector<bst_float> **out_preds,
uint32_t iteration_begin,
uint32_t iteration_end) override {
this->Configure();
auto& out_predictions = this->GetThreadLocal().prediction_entry;
this->gbm_->InplacePredict(x, missing, &out_predictions, layer_begin,
layer_end);
if (type == "value") {
this->gbm_->InplacePredict(x, p_m, missing, &out_predictions,
iteration_begin, iteration_end);
if (type == PredictionType::kValue) {
obj_->PredTransform(&out_predictions.predictions);
} else if (type == "margin") {
} else if (type == PredictionType::kMargin) {
// do nothing
} else {
LOG(FATAL) << "Unsupported prediction type:" << type;
LOG(FATAL) << "Unsupported prediction type:" << static_cast<int>(type);
}
*out_preds = &out_predictions.predictions;
}