Rename and extract Context. (#8528)
* Rename `GenericParameter` to `Context`. * Rename header file to reflect the change. * Rename all references.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright by Contributors 2017-2021
|
||||
* Copyright by XGBoost Contributors 2017-2022
|
||||
*/
|
||||
#include <dmlc/any.h>
|
||||
#include <dmlc/omp.h>
|
||||
@@ -351,8 +351,7 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
|
||||
public:
|
||||
explicit CPUPredictor(GenericParameter const* generic_param) :
|
||||
Predictor::Predictor{generic_param} {}
|
||||
explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {}
|
||||
|
||||
void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
|
||||
const gbm::GBTreeModel &model, uint32_t tree_begin,
|
||||
@@ -614,9 +613,7 @@ class CPUPredictor : public Predictor {
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor")
|
||||
.describe("Make predictions using CPU.")
|
||||
.set_body([](GenericParameter const* generic_param) {
|
||||
return new CPUPredictor(generic_param);
|
||||
});
|
||||
.describe("Make predictions using CPU.")
|
||||
.set_body([](Context const *ctx) { return new CPUPredictor(ctx); });
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -723,8 +723,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
public:
|
||||
explicit GPUPredictor(GenericParameter const* generic_param) :
|
||||
Predictor::Predictor{generic_param} {}
|
||||
explicit GPUPredictor(Context const* ctx) : Predictor::Predictor{ctx} {}
|
||||
|
||||
~GPUPredictor() override {
|
||||
if (ctx_->gpu_id >= 0 && ctx_->gpu_id < common::AllVisibleGPUs()) {
|
||||
@@ -1026,10 +1025,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
|
||||
.describe("Make predictions using GPU.")
|
||||
.set_body([](GenericParameter const* generic_param) {
|
||||
return new GPUPredictor(generic_param);
|
||||
});
|
||||
.describe("Make predictions using GPU.")
|
||||
.set_body([](Context const* ctx) { return new GPUPredictor(ctx); });
|
||||
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
*/
|
||||
#include "xgboost/predictor.h"
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include <mutex>
|
||||
|
||||
#include "xgboost/predictor.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "xgboost/generic_parameters.h"
|
||||
|
||||
#include "../gbm/gbtree.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
|
||||
@@ -30,7 +31,7 @@ void PredictionContainer::ClearExpiredEntries() {
|
||||
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
||||
this->ClearExpiredEntries();
|
||||
container_[m.get()].ref = m;
|
||||
if (device != GenericParameter::kCpuId) {
|
||||
if (device != Context::kCpuId) {
|
||||
container_[m.get()].predictions.SetDevice(device);
|
||||
}
|
||||
return container_[m.get()];
|
||||
@@ -51,13 +52,12 @@ decltype(PredictionContainer::container_) const& PredictionContainer::Container(
|
||||
void Predictor::Configure(
|
||||
const std::vector<std::pair<std::string, std::string>>&) {
|
||||
}
|
||||
Predictor* Predictor::Create(
|
||||
std::string const& name, GenericParameter const* generic_param) {
|
||||
Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
|
||||
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown predictor type " << name;
|
||||
}
|
||||
auto p_predictor = (e->body)(generic_param);
|
||||
auto p_predictor = (e->body)(ctx);
|
||||
return p_predictor;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user