Pass infomation about objective to tree methods. (#7385)
* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
@@ -11,15 +11,16 @@
|
||||
#include <dmlc/any.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <xgboost/task.h>
|
||||
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace xgboost {
|
||||
@@ -307,11 +308,13 @@ struct LearnerModelParam {
|
||||
uint32_t num_feature { 0 };
|
||||
/* \brief number of classes, if it is multi-class classification */
|
||||
uint32_t num_output_group { 0 };
|
||||
/* \brief Current task, determined by objective. */
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
|
||||
LearnerModelParam() = default;
|
||||
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
|
||||
// this one as an immutable copy.
|
||||
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin);
|
||||
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin, ObjInfo t);
|
||||
/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
|
||||
bool Initialized() const { return num_feature != 0; }
|
||||
};
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/task.h>
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
@@ -72,6 +73,11 @@ class ObjFunction : public Configurable {
|
||||
virtual bst_float ProbToMargin(bst_float base_score) const {
|
||||
return base_score;
|
||||
}
|
||||
/*!
|
||||
* \brief Return task of this objective.
|
||||
*/
|
||||
virtual struct ObjInfo Task() const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Create an objective function according to name.
|
||||
* \param tparam Generic parameters.
|
||||
|
||||
39
include/xgboost/task.h
Normal file
39
include/xgboost/task.h
Normal file
@@ -0,0 +1,39 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_TASK_H_
|
||||
#define XGBOOST_TASK_H_
|
||||
|
||||
#include <cinttypes>
|
||||
|
||||
namespace xgboost {
|
||||
/*!
|
||||
* \brief A struct returned by objective, which determines task at hand. The struct is
|
||||
* not used by any algorithm yet, only for future development like categorical
|
||||
* split.
|
||||
*
|
||||
* The task field is useful for tree split finding, also for some metrics like auc.
|
||||
* Lastly, knowing whether hessian is constant can allow some optimizations like skipping
|
||||
* the quantile sketching.
|
||||
*
|
||||
* This struct should not be serialized since it can be recovered from objective function,
|
||||
* hence it doesn't need to be stable.
|
||||
*/
|
||||
struct ObjInfo {
|
||||
// What kind of problem are we trying to solve
|
||||
enum Task : uint8_t {
|
||||
kRegression = 0,
|
||||
kBinary = 1,
|
||||
kClassification = 2,
|
||||
kSurvival = 3,
|
||||
kRanking = 4,
|
||||
kOther = 5,
|
||||
} task;
|
||||
// Does the objective have constant hessian value?
|
||||
bool const_hess{false};
|
||||
|
||||
explicit ObjInfo(Task t) : task{t} {}
|
||||
ObjInfo(Task t, bool khess) : const_hess{khess} {}
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TASK_H_
|
||||
@@ -11,16 +11,17 @@
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
#include <xgboost/generic_parameters.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/linalg.h>
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/task.h>
|
||||
#include <xgboost/tree_model.h>
|
||||
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -83,7 +84,7 @@ class TreeUpdater : public Configurable {
|
||||
* \param name Name of the tree updater.
|
||||
* \param tparam A global runtime parameter
|
||||
*/
|
||||
static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam);
|
||||
static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam, ObjInfo task);
|
||||
};
|
||||
|
||||
/*!
|
||||
@@ -91,8 +92,7 @@ class TreeUpdater : public Configurable {
|
||||
*/
|
||||
struct TreeUpdaterReg
|
||||
: public dmlc::FunctionRegEntryBase<TreeUpdaterReg,
|
||||
std::function<TreeUpdater* ()> > {
|
||||
};
|
||||
std::function<TreeUpdater*(ObjInfo task)> > {};
|
||||
|
||||
/*!
|
||||
* \brief Macro to register tree updater.
|
||||
|
||||
Reference in New Issue
Block a user