Pass infomation about objective to tree methods. (#7385)

* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
Jiaming Yuan
2021-11-04 01:52:44 +08:00
committed by GitHub
parent ccdabe4512
commit 4100827971
28 changed files with 178 additions and 69 deletions

View File

@@ -11,15 +11,16 @@
#include <dmlc/any.h>
#include <xgboost/base.h>
#include <xgboost/feature_map.h>
#include <xgboost/predictor.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/model.h>
#include <xgboost/predictor.h>
#include <xgboost/task.h>
#include <utility>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace xgboost {
@@ -307,11 +308,13 @@ struct LearnerModelParam {
uint32_t num_feature { 0 };
/* \brief number of classes, if it is multi-class classification */
uint32_t num_output_group { 0 };
/* \brief Current task, determined by objective. */
ObjInfo task{ObjInfo::kRegression};
LearnerModelParam() = default;
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
// this one as an immutable copy.
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin);
LearnerModelParam(LearnerModelParamLegacy const& user_param, float base_margin, ObjInfo t);
/* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
bool Initialized() const { return num_feature != 0; }
};

View File

@@ -13,6 +13,7 @@
#include <xgboost/model.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/task.h>
#include <vector>
#include <utility>
@@ -72,6 +73,11 @@ class ObjFunction : public Configurable {
virtual bst_float ProbToMargin(bst_float base_score) const {
return base_score;
}
/*!
* \brief Return task of this objective.
*/
virtual struct ObjInfo Task() const = 0;
/*!
* \brief Create an objective function according to name.
* \param tparam Generic parameters.

39
include/xgboost/task.h Normal file
View File

@@ -0,0 +1,39 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#ifndef XGBOOST_TASK_H_
#define XGBOOST_TASK_H_
#include <cinttypes>
namespace xgboost {
/*!
* \brief A struct returned by objective, which determines task at hand. The struct is
* not used by any algorithm yet, only for future development like categorical
* split.
*
* The task field is useful for tree split finding, also for some metrics like auc.
* Lastly, knowing whether hessian is constant can allow some optimizations like skipping
* the quantile sketching.
*
* This struct should not be serialized since it can be recovered from objective function,
* hence it doesn't need to be stable.
*/
struct ObjInfo {
// What kind of problem are we trying to solve
enum Task : uint8_t {
kRegression = 0,
kBinary = 1,
kClassification = 2,
kSurvival = 3,
kRanking = 4,
kOther = 5,
} task;
// Does the objective have constant hessian value?
bool const_hess{false};
explicit ObjInfo(Task t) : task{t} {}
ObjInfo(Task t, bool khess) : const_hess{khess} {}
};
} // namespace xgboost
#endif // XGBOOST_TASK_H_

View File

@@ -11,16 +11,17 @@
#include <dmlc/registry.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/tree_model.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/model.h>
#include <xgboost/linalg.h>
#include <xgboost/model.h>
#include <xgboost/task.h>
#include <xgboost/tree_model.h>
#include <functional>
#include <vector>
#include <utility>
#include <string>
#include <utility>
#include <vector>
namespace xgboost {
@@ -83,7 +84,7 @@ class TreeUpdater : public Configurable {
* \param name Name of the tree updater.
* \param tparam A global runtime parameter
*/
static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam);
static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam, ObjInfo task);
};
/*!
@@ -91,8 +92,7 @@ class TreeUpdater : public Configurable {
*/
struct TreeUpdaterReg
: public dmlc::FunctionRegEntryBase<TreeUpdaterReg,
std::function<TreeUpdater* ()> > {
};
std::function<TreeUpdater*(ObjInfo task)> > {};
/*!
* \brief Macro to register tree updater.