Rename and extract Context. (#8528)
* Rename `GenericParameter` to `Context`. * Rename header file to reflect the change. * Rename all references.
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
/*!
|
||||
* Copyright 2014-2019 by Contributors
|
||||
* \file generic_parameters.h
|
||||
* Copyright 2014-2022 by Contributors
|
||||
* \file context.h
|
||||
*/
|
||||
#ifndef XGBOOST_GENERIC_PARAMETERS_H_
|
||||
#define XGBOOST_GENERIC_PARAMETERS_H_
|
||||
#ifndef XGBOOST_CONTEXT_H_
|
||||
#define XGBOOST_CONTEXT_H_
|
||||
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/parameter.h>
|
||||
@@ -12,31 +12,31 @@
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
struct GenericParameter : public XGBoostParameter<GenericParameter> {
|
||||
struct Context : public XGBoostParameter<Context> {
|
||||
private:
|
||||
// cached value for CFS CPU limit. (used in containerized env)
|
||||
int32_t cfs_cpu_count_; // NOLINT
|
||||
std::int32_t cfs_cpu_count_; // NOLINT
|
||||
|
||||
public:
|
||||
// Constant representing the device ID of CPU.
|
||||
static int32_t constexpr kCpuId = -1;
|
||||
static int64_t constexpr kDefaultSeed = 0;
|
||||
static std::int32_t constexpr kCpuId = -1;
|
||||
static std::int64_t constexpr kDefaultSeed = 0;
|
||||
|
||||
public:
|
||||
GenericParameter();
|
||||
Context();
|
||||
|
||||
// stored random seed
|
||||
int64_t seed { kDefaultSeed };
|
||||
std::int64_t seed{kDefaultSeed};
|
||||
// whether seed the PRNG each iteration
|
||||
bool seed_per_iteration{false};
|
||||
// number of threads to use if OpenMP is enabled
|
||||
// if equals 0, use system default
|
||||
int nthread{0};
|
||||
std::int32_t nthread{0};
|
||||
// primary device, -1 means no gpu.
|
||||
int gpu_id{kCpuId};
|
||||
std::int32_t gpu_id{kCpuId};
|
||||
// fail when gpu_id is invalid
|
||||
bool fail_on_invalid_gpu_id {false};
|
||||
bool validate_parameters {false};
|
||||
bool fail_on_invalid_gpu_id{false};
|
||||
bool validate_parameters{false};
|
||||
|
||||
/*!
|
||||
* \brief Configure the parameter `gpu_id'.
|
||||
@@ -47,26 +47,25 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
|
||||
/*!
|
||||
* Return automatically chosen threads.
|
||||
*/
|
||||
int32_t Threads() const;
|
||||
std::int32_t Threads() const;
|
||||
|
||||
bool IsCPU() const { return gpu_id == kCpuId; }
|
||||
bool IsCUDA() const { return !IsCPU(); }
|
||||
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GenericParameter) {
|
||||
DMLC_DECLARE_FIELD(seed).set_default(kDefaultSeed).describe(
|
||||
"Random number seed during training.");
|
||||
DMLC_DECLARE_PARAMETER(Context) {
|
||||
DMLC_DECLARE_FIELD(seed)
|
||||
.set_default(kDefaultSeed)
|
||||
.describe("Random number seed during training.");
|
||||
DMLC_DECLARE_ALIAS(seed, random_state);
|
||||
DMLC_DECLARE_FIELD(seed_per_iteration)
|
||||
.set_default(false)
|
||||
.describe("Seed PRNG determnisticly via iterator number.");
|
||||
DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
|
||||
"Number of threads to use.");
|
||||
DMLC_DECLARE_FIELD(nthread).set_default(0).describe("Number of threads to use.");
|
||||
DMLC_DECLARE_ALIAS(nthread, n_jobs);
|
||||
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_default(-1)
|
||||
.set_lower_bound(-1)
|
||||
.describe("The primary GPU device ordinal.");
|
||||
DMLC_DECLARE_FIELD(gpu_id).set_default(-1).set_lower_bound(-1).describe(
|
||||
"The primary GPU device ordinal.");
|
||||
DMLC_DECLARE_FIELD(fail_on_invalid_gpu_id)
|
||||
.set_default(false)
|
||||
.describe("Fail with error when gpu_id is invalid.");
|
||||
@@ -75,8 +74,6 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
|
||||
.describe("Enable checking whether parameters are used or not.");
|
||||
}
|
||||
};
|
||||
|
||||
using Context = GenericParameter;
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_GENERIC_PARAMETERS_H_
|
||||
#endif // XGBOOST_CONTEXT_H_
|
||||
@@ -11,7 +11,6 @@
|
||||
#include <dmlc/data.h>
|
||||
#include <dmlc/serializer.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/linalg.h>
|
||||
#include <xgboost/span.h>
|
||||
@@ -28,6 +27,7 @@
|
||||
namespace xgboost {
|
||||
// forward declare dmatrix.
|
||||
class DMatrix;
|
||||
struct Context;
|
||||
|
||||
/*! \brief data type accepted by xgboost interface */
|
||||
enum class DataType : uint8_t {
|
||||
|
||||
@@ -28,7 +28,7 @@ class Json;
|
||||
class FeatureMap;
|
||||
class ObjFunction;
|
||||
|
||||
struct GenericParameter;
|
||||
struct Context;
|
||||
struct LearnerModelParam;
|
||||
struct PredictionCacheEntry;
|
||||
class PredictionContainer;
|
||||
@@ -38,8 +38,8 @@ class PredictionContainer;
|
||||
*/
|
||||
class GradientBooster : public Model, public Configurable {
|
||||
protected:
|
||||
GenericParameter const* ctx_;
|
||||
explicit GradientBooster(GenericParameter const* ctx) : ctx_{ctx} {}
|
||||
Context const* ctx_;
|
||||
explicit GradientBooster(Context const* ctx) : ctx_{ctx} {}
|
||||
|
||||
public:
|
||||
/*! \brief virtual destructor */
|
||||
@@ -193,10 +193,8 @@ class GradientBooster : public Model, public Configurable {
|
||||
* \param learner_model_param pointer to global model parameters
|
||||
* \return The created booster.
|
||||
*/
|
||||
static GradientBooster* Create(
|
||||
const std::string& name,
|
||||
GenericParameter const* generic_param,
|
||||
LearnerModelParam const* learner_model_param);
|
||||
static GradientBooster* Create(const std::string& name, Context const* ctx,
|
||||
LearnerModelParam const* learner_model_param);
|
||||
};
|
||||
|
||||
/*!
|
||||
@@ -206,7 +204,7 @@ struct GradientBoosterReg
|
||||
: public dmlc::FunctionRegEntryBase<
|
||||
GradientBoosterReg,
|
||||
std::function<GradientBooster*(LearnerModelParam const* learner_model_param,
|
||||
GenericParameter const* ctx)> > {};
|
||||
Context const* ctx)> > {};
|
||||
|
||||
/*!
|
||||
* \brief Macro to register gradient booster.
|
||||
|
||||
@@ -9,8 +9,8 @@
|
||||
#define XGBOOST_LEARNER_H_
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/context.h> // Context
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/generic_parameters.h> // Context
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/predictor.h>
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
#include <dmlc/endian.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/context.h> // fixme(jiamingy): Remove the dependency on this header.
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/span.h>
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/model.h>
|
||||
|
||||
@@ -19,6 +18,7 @@
|
||||
namespace xgboost {
|
||||
|
||||
class Json;
|
||||
struct Context;
|
||||
|
||||
namespace gbm {
|
||||
class GBLinearModel;
|
||||
@@ -29,7 +29,7 @@ class GBLinearModel;
|
||||
*/
|
||||
class LinearUpdater : public Configurable {
|
||||
protected:
|
||||
GenericParameter const* ctx_;
|
||||
Context const* ctx_;
|
||||
|
||||
public:
|
||||
/*! \brief virtual destructor */
|
||||
@@ -57,7 +57,7 @@ class LinearUpdater : public Configurable {
|
||||
* \brief Create a linear updater given name
|
||||
* \param name Name of the linear updater.
|
||||
*/
|
||||
static LinearUpdater* Create(const std::string& name, GenericParameter const*);
|
||||
static LinearUpdater* Create(const std::string& name, Context const*);
|
||||
};
|
||||
|
||||
/*!
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
@@ -20,13 +19,15 @@
|
||||
#include <utility>
|
||||
|
||||
namespace xgboost {
|
||||
struct Context;
|
||||
|
||||
/*!
|
||||
* \brief interface of evaluation metric used to evaluate model performance.
|
||||
* This has nothing to do with training, but merely act as evaluation purpose.
|
||||
*/
|
||||
class Metric : public Configurable {
|
||||
protected:
|
||||
GenericParameter const* tparam_;
|
||||
Context const* tparam_;
|
||||
|
||||
public:
|
||||
/*!
|
||||
@@ -68,10 +69,10 @@ class Metric : public Configurable {
|
||||
* \param name name of the metric.
|
||||
* name can be in form metric[@]param and the name will be matched in the
|
||||
* registry.
|
||||
* \param tparam A global generic parameter
|
||||
* \param ctx A global context
|
||||
* \return the created metric.
|
||||
*/
|
||||
static Metric* Create(const std::string& name, GenericParameter const* tparam);
|
||||
static Metric* Create(const std::string& name, Context const* ctx);
|
||||
};
|
||||
|
||||
/*!
|
||||
|
||||
@@ -10,19 +10,19 @@
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/task.h>
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class RegTree;
|
||||
struct Context;
|
||||
|
||||
/*! \brief interface of objective function */
|
||||
class ObjFunction : public Configurable {
|
||||
@@ -120,10 +120,10 @@ class ObjFunction : public Configurable {
|
||||
|
||||
/*!
|
||||
* \brief Create an objective function according to name.
|
||||
* \param tparam Generic parameters.
|
||||
* \param ctx Pointer to runtime parameters.
|
||||
* \param name Name of the objective.
|
||||
*/
|
||||
static ObjFunction* Create(const std::string& name, GenericParameter const* tparam);
|
||||
static ObjFunction* Create(const std::string& name, Context const* ctx);
|
||||
};
|
||||
|
||||
/*!
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
* Copyright 2017-2022 by Contributors
|
||||
* \file predictor.h
|
||||
* \brief Interface of predictor,
|
||||
* performs predictions for a gradient booster.
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
|
||||
// Forward declarations
|
||||
namespace xgboost {
|
||||
@@ -73,7 +73,7 @@ class PredictionContainer {
|
||||
*
|
||||
* \param m shared pointer to the DMatrix that needs to be cached.
|
||||
* \param device Which device should the cache be allocated on. Pass
|
||||
* GenericParameter::kCpuId for CPU or positive integer for GPU id.
|
||||
* Context::kCpuId for CPU or positive integer for GPU id.
|
||||
*
|
||||
* \return the cache entry for passed in DMatrix, either an existing cache or newly
|
||||
* created.
|
||||
@@ -218,19 +218,17 @@ class Predictor {
|
||||
/**
|
||||
* \brief Creates a new Predictor*.
|
||||
*
|
||||
* \param name Name of the predictor.
|
||||
* \param generic_param Pointer to runtime parameters.
|
||||
* \param name Name of the predictor.
|
||||
* \param ctx Pointer to runtime parameters.
|
||||
*/
|
||||
static Predictor* Create(
|
||||
std::string const& name, GenericParameter const* generic_param);
|
||||
static Predictor* Create(std::string const& name, Context const* ctx);
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Registry entry for predictor.
|
||||
*/
|
||||
struct PredictorReg
|
||||
: public dmlc::FunctionRegEntryBase<
|
||||
PredictorReg, std::function<Predictor*(GenericParameter const*)>> {};
|
||||
: public dmlc::FunctionRegEntryBase<PredictorReg, std::function<Predictor*(Context const*)>> {};
|
||||
|
||||
#define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \
|
||||
static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \
|
||||
|
||||
@@ -10,8 +10,8 @@
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/linalg.h>
|
||||
#include <xgboost/model.h>
|
||||
@@ -26,16 +26,17 @@
|
||||
namespace xgboost {
|
||||
|
||||
class Json;
|
||||
struct Context;
|
||||
|
||||
/*!
|
||||
* \brief interface of tree update module, that performs update of a tree.
|
||||
*/
|
||||
class TreeUpdater : public Configurable {
|
||||
protected:
|
||||
GenericParameter const* ctx_ = nullptr;
|
||||
Context const* ctx_ = nullptr;
|
||||
|
||||
public:
|
||||
explicit TreeUpdater(const GenericParameter* ctx) : ctx_(ctx) {}
|
||||
explicit TreeUpdater(const Context* ctx) : ctx_(ctx) {}
|
||||
/*! \brief virtual destructor */
|
||||
~TreeUpdater() override = default;
|
||||
/*!
|
||||
@@ -90,9 +91,9 @@ class TreeUpdater : public Configurable {
|
||||
/*!
|
||||
* \brief Create a tree updater given name
|
||||
* \param name Name of the tree updater.
|
||||
* \param tparam A global runtime parameter
|
||||
* \param ctx A global runtime parameter
|
||||
*/
|
||||
static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam, ObjInfo task);
|
||||
static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo task);
|
||||
};
|
||||
|
||||
/*!
|
||||
@@ -100,8 +101,7 @@ class TreeUpdater : public Configurable {
|
||||
*/
|
||||
struct TreeUpdaterReg
|
||||
: public dmlc::FunctionRegEntryBase<
|
||||
TreeUpdaterReg,
|
||||
std::function<TreeUpdater*(GenericParameter const* tparam, ObjInfo task)> > {};
|
||||
TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo task)>> {};
|
||||
|
||||
/*!
|
||||
* \brief Macro to register tree updater.
|
||||
|
||||
Reference in New Issue
Block a user