Pass information about objective to tree methods. (#7385)

* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
Jiaming Yuan
2021-11-04 01:52:44 +08:00
committed by GitHub
parent ccdabe4512
commit 4100827971
28 changed files with 178 additions and 69 deletions

View File

@@ -38,6 +38,8 @@ class AFTObj : public ObjFunction {
param_.UpdateAllowUnknown(args);
}
ObjInfo Task() const override { return {ObjInfo::kSurvival, false}; }
template <typename Distribution>
void GetGradientImpl(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,

View File

@@ -27,6 +27,8 @@ class HingeObj : public ObjFunction {
void Configure(
const std::vector<std::pair<std::string, std::string> > &args) override {}
ObjInfo Task() const override { return {ObjInfo::kRegression, false}; }
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
int iter,

View File

@@ -45,6 +45,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
void Configure(Args const& args) override {
param_.UpdateAllowUnknown(args);
}
ObjInfo Task() const override { return {ObjInfo::kClassification, false}; }
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,

View File

@@ -754,6 +754,8 @@ class LambdaRankObj : public ObjFunction {
param_.UpdateAllowUnknown(args);
}
ObjInfo Task() const override { return {ObjInfo::kRanking, false}; }
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,

View File

@@ -7,6 +7,8 @@
#include <dmlc/omp.h>
#include <xgboost/logging.h>
#include <algorithm>
#include "xgboost/task.h"
#include "../common/math.h"
namespace xgboost {
@@ -36,6 +38,7 @@ struct LinearSquareLoss {
static const char* DefaultEvalMetric() { return "rmse"; }
static const char* Name() { return "reg:squarederror"; }
static ObjInfo Info() { return {ObjInfo::kRegression, true}; }
};
struct SquaredLogError {
@@ -61,6 +64,8 @@ struct SquaredLogError {
static const char* DefaultEvalMetric() { return "rmsle"; }
static const char* Name() { return "reg:squaredlogerror"; }
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
};
// logistic loss for probability regression task
@@ -96,6 +101,8 @@ struct LogisticRegression {
static const char* DefaultEvalMetric() { return "rmse"; }
static const char* Name() { return "reg:logistic"; }
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
};
struct PseudoHuberError {
@@ -127,12 +134,14 @@ struct PseudoHuberError {
static const char* Name() {
return "reg:pseudohubererror";
}
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
};
// logistic loss for binary classification task
struct LogisticClassification : public LogisticRegression {
static const char* DefaultEvalMetric() { return "logloss"; }
static const char* Name() { return "binary:logistic"; }
static ObjInfo Info() { return {ObjInfo::kBinary, false}; }
};
// logistic loss, but predict un-transformed margin
@@ -168,6 +177,8 @@ struct LogisticRaw : public LogisticRegression {
static const char* DefaultEvalMetric() { return "logloss"; }
static const char* Name() { return "binary:logitraw"; }
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
};
} // namespace obj

View File

@@ -52,6 +52,10 @@ class RegLossObj : public ObjFunction {
param_.UpdateAllowUnknown(args);
}
struct ObjInfo Task() const override {
return Loss::Info();
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info, int,
HostDeviceVector<GradientPair>* out_gpair) override {
@@ -207,6 +211,10 @@ class PoissonRegression : public ObjFunction {
param_.UpdateAllowUnknown(args);
}
struct ObjInfo Task() const override {
return {ObjInfo::kRegression, false};
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info, int,
HostDeviceVector<GradientPair> *out_gpair) override {
@@ -298,6 +306,10 @@ class CoxRegression : public ObjFunction {
void Configure(
const std::vector<std::pair<std::string, std::string> >&) override {}
struct ObjInfo Task() const override {
return {ObjInfo::kRegression, false};
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info, int,
HostDeviceVector<GradientPair> *out_gpair) override {
@@ -395,6 +407,10 @@ class GammaRegression : public ObjFunction {
void Configure(
const std::vector<std::pair<std::string, std::string> >&) override {}
struct ObjInfo Task() const override {
return {ObjInfo::kRegression, false};
}
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info, int,
HostDeviceVector<GradientPair> *out_gpair) override {
@@ -491,6 +507,10 @@ class TweedieRegression : public ObjFunction {
metric_ = os.str();
}
struct ObjInfo Task() const override {
return {ObjInfo::kRegression, false};
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info, int,
HostDeviceVector<GradientPair> *out_gpair) override {