Enhance inplace prediction. (#6653)
* Accept array interface for csr and array. * Accept an optional proxy dmatrix for metainfo. This constructs an explicit `_ProxyDMatrix` type in Python. * Remove unused doc. * Add strict output.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* Copyright 2014-2021 by Contributors
|
||||
* \file learner.cc
|
||||
* \brief Implementation of learning algorithm.
|
||||
* \author Tianqi Chen
|
||||
@@ -1110,23 +1110,30 @@ class LearnerImpl : public LearnerIO {
|
||||
CHECK(!this->need_configuration_);
|
||||
return this->gbm_->BoostedRounds();
|
||||
}
|
||||
uint32_t Groups() const override {
|
||||
CHECK(!this->need_configuration_);
|
||||
return this->learner_model_param_.num_output_group;
|
||||
}
|
||||
|
||||
XGBAPIThreadLocalEntry& GetThreadLocal() const override {
|
||||
return (*LearnerAPIThreadLocalStore::Get())[this];
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, std::string const &type,
|
||||
float missing, HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t layer_begin, uint32_t layer_end) override {
|
||||
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
PredictionType type, float missing,
|
||||
HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t iteration_begin,
|
||||
uint32_t iteration_end) override {
|
||||
this->Configure();
|
||||
auto& out_predictions = this->GetThreadLocal().prediction_entry;
|
||||
this->gbm_->InplacePredict(x, missing, &out_predictions, layer_begin,
|
||||
layer_end);
|
||||
if (type == "value") {
|
||||
this->gbm_->InplacePredict(x, p_m, missing, &out_predictions,
|
||||
iteration_begin, iteration_end);
|
||||
if (type == PredictionType::kValue) {
|
||||
obj_->PredTransform(&out_predictions.predictions);
|
||||
} else if (type == "margin") {
|
||||
} else if (type == PredictionType::kMargin) {
|
||||
// do nothing
|
||||
} else {
|
||||
LOG(FATAL) << "Unsupported prediction type:" << type;
|
||||
LOG(FATAL) << "Unsupported prediction type:" << static_cast<int>(type);
|
||||
}
|
||||
*out_preds = &out_predictions.predictions;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user