Rename and extract Context. (#8528)

* Rename `GenericParameter` to `Context`.
* Rename header file to reflect the change.
* Rename all references.
This commit is contained in:
Jiaming Yuan
2022-12-07 04:58:54 +08:00
committed by GitHub
parent 05fc6f3ca9
commit 3e26107a9c
105 changed files with 548 additions and 574 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright by Contributors 2017-2021
* Copyright by XGBoost Contributors 2017-2022
*/
#include <dmlc/any.h>
#include <dmlc/omp.h>
@@ -351,8 +351,7 @@ class CPUPredictor : public Predictor {
}
public:
explicit CPUPredictor(GenericParameter const* generic_param) :
Predictor::Predictor{generic_param} {}
explicit CPUPredictor(Context const *ctx) : Predictor::Predictor{ctx} {}
void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
const gbm::GBTreeModel &model, uint32_t tree_begin,
@@ -614,9 +613,7 @@ class CPUPredictor : public Predictor {
};
XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor")
.describe("Make predictions using CPU.")
.set_body([](GenericParameter const* generic_param) {
return new CPUPredictor(generic_param);
});
.describe("Make predictions using CPU.")
.set_body([](Context const *ctx) { return new CPUPredictor(ctx); });
} // namespace predictor
} // namespace xgboost

View File

@@ -723,8 +723,7 @@ class GPUPredictor : public xgboost::Predictor {
}
public:
explicit GPUPredictor(GenericParameter const* generic_param) :
Predictor::Predictor{generic_param} {}
explicit GPUPredictor(Context const* ctx) : Predictor::Predictor{ctx} {}
~GPUPredictor() override {
if (ctx_->gpu_id >= 0 && ctx_->gpu_id < common::AllVisibleGPUs()) {
@@ -1026,10 +1025,8 @@ class GPUPredictor : public xgboost::Predictor {
};
XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
.describe("Make predictions using GPU.")
.set_body([](GenericParameter const* generic_param) {
return new GPUPredictor(generic_param);
});
.describe("Make predictions using GPU.")
.set_body([](Context const* ctx) { return new GPUPredictor(ctx); });
} // namespace predictor
} // namespace xgboost

View File

@@ -1,14 +1,15 @@
/*!
* Copyright 2017-2021 by Contributors
*/
#include "xgboost/predictor.h"
#include <dmlc/registry.h>
#include <mutex>
#include "xgboost/predictor.h"
#include "xgboost/data.h"
#include "xgboost/generic_parameters.h"
#include "../gbm/gbtree.h"
#include "xgboost/context.h"
#include "xgboost/data.h"
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
@@ -30,7 +31,7 @@ void PredictionContainer::ClearExpiredEntries() {
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
this->ClearExpiredEntries();
container_[m.get()].ref = m;
if (device != GenericParameter::kCpuId) {
if (device != Context::kCpuId) {
container_[m.get()].predictions.SetDevice(device);
}
return container_[m.get()];
@@ -51,13 +52,12 @@ decltype(PredictionContainer::container_) const& PredictionContainer::Container(
void Predictor::Configure(
const std::vector<std::pair<std::string, std::string>>&) {
}
Predictor* Predictor::Create(
std::string const& name, GenericParameter const* generic_param) {
Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown predictor type " << name;
}
auto p_predictor = (e->body)(generic_param);
auto p_predictor = (e->body)(ctx);
return p_predictor;
}