/*! * Copyright 2017-2021 by Contributors */ #include #include #include "xgboost/predictor.h" #include "xgboost/data.h" #include "xgboost/generic_parameters.h" #include "../gbm/gbtree.h" namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg); } // namespace dmlc namespace xgboost { void PredictionContainer::ClearExpiredEntries() { std::vector expired; for (auto& kv : container_) { if (kv.second.ref.expired()) { expired.emplace_back(kv.first); } } for (auto const& ptr : expired) { container_.erase(ptr); } } PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr m, int32_t device) { this->ClearExpiredEntries(); container_[m.get()].ref = m; if (device != GenericParameter::kCpuId) { container_[m.get()].predictions.SetDevice(device); } return container_[m.get()]; } PredictionCacheEntry &PredictionContainer::Entry(DMatrix *m) { CHECK(container_.find(m) != container_.cend()); CHECK(container_.at(m).ref.lock()) << "[Internal error]: DMatrix: " << m << " has expired."; return container_.at(m); } decltype(PredictionContainer::container_) const& PredictionContainer::Container() { this->ClearExpiredEntries(); return container_; } void Predictor::Configure( const std::vector>&) { } Predictor* Predictor::Create( std::string const& name, GenericParameter const* generic_param) { auto* e = ::dmlc::Registry::Get()->Find(name); if (e == nullptr) { LOG(FATAL) << "Unknown predictor type " << name; } auto p_predictor = (e->body)(generic_param); return p_predictor; } template void ValidateBaseMarginShape(linalg::Tensor const& margin, bst_row_t n_samples, bst_group_t n_groups) { // FIXME: Bindings other than Python doesn't have shape. std::string expected{"Invalid shape of base_margin. Expected: (" + std::to_string(n_samples) + ", " + std::to_string(n_groups) + ")"}; CHECK_EQ(margin.Shape(0), n_samples) << expected; CHECK_EQ(margin.Shape(1), n_groups) << expected; } void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector* out_preds, const gbm::GBTreeModel& model) const { CHECK_NE(model.learner_model_param->num_output_group, 0); size_t n_classes = model.learner_model_param->num_output_group; size_t n = n_classes * info.num_row_; const HostDeviceVector* base_margin = info.base_margin_.Data(); if (ctx_->gpu_id >= 0) { out_preds->SetDevice(ctx_->gpu_id); } if (base_margin->Size() != 0) { out_preds->Resize(n); ValidateBaseMarginShape(info.base_margin_, info.num_row_, n_classes); out_preds->Copy(*base_margin); } else { out_preds->Resize(n); // cannot rely on the Resize to fill as it might skip if the size is already correct. out_preds->Fill(model.learner_model_param->base_score); } } } // namespace xgboost namespace xgboost { namespace predictor { // List of files that will be force linked in static links. #ifdef XGBOOST_USE_CUDA DMLC_REGISTRY_LINK_TAG(gpu_predictor); #endif // XGBOOST_USE_CUDA DMLC_REGISTRY_LINK_TAG(cpu_predictor); } // namespace predictor } // namespace xgboost