Pass infomation about objective to tree methods. (#7385)
* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
@@ -38,6 +38,8 @@ class AFTObj : public ObjFunction {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
|
||||
ObjInfo Task() const override { return {ObjInfo::kSurvival, false}; }
|
||||
|
||||
template <typename Distribution>
|
||||
void GetGradientImpl(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
|
||||
@@ -27,6 +27,8 @@ class HingeObj : public ObjFunction {
|
||||
void Configure(
|
||||
const std::vector<std::pair<std::string, std::string> > &args) override {}
|
||||
|
||||
ObjInfo Task() const override { return {ObjInfo::kRegression, false}; }
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
|
||||
@@ -45,6 +45,9 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
void Configure(Args const& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
|
||||
ObjInfo Task() const override { return {ObjInfo::kClassification, false}; }
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
int iter,
|
||||
|
||||
@@ -754,6 +754,8 @@ class LambdaRankObj : public ObjFunction {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
|
||||
ObjInfo Task() const override { return {ObjInfo::kRanking, false}; }
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
int iter,
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include <dmlc/omp.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <algorithm>
|
||||
|
||||
#include "xgboost/task.h"
|
||||
#include "../common/math.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -36,6 +38,7 @@ struct LinearSquareLoss {
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
|
||||
static const char* Name() { return "reg:squarederror"; }
|
||||
static ObjInfo Info() { return {ObjInfo::kRegression, true}; }
|
||||
};
|
||||
|
||||
struct SquaredLogError {
|
||||
@@ -61,6 +64,8 @@ struct SquaredLogError {
|
||||
static const char* DefaultEvalMetric() { return "rmsle"; }
|
||||
|
||||
static const char* Name() { return "reg:squaredlogerror"; }
|
||||
|
||||
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||
};
|
||||
|
||||
// logistic loss for probability regression task
|
||||
@@ -96,6 +101,8 @@ struct LogisticRegression {
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
|
||||
static const char* Name() { return "reg:logistic"; }
|
||||
|
||||
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||
};
|
||||
|
||||
struct PseudoHuberError {
|
||||
@@ -127,12 +134,14 @@ struct PseudoHuberError {
|
||||
static const char* Name() {
|
||||
return "reg:pseudohubererror";
|
||||
}
|
||||
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||
};
|
||||
|
||||
// logistic loss for binary classification task
|
||||
struct LogisticClassification : public LogisticRegression {
|
||||
static const char* DefaultEvalMetric() { return "logloss"; }
|
||||
static const char* Name() { return "binary:logistic"; }
|
||||
static ObjInfo Info() { return {ObjInfo::kBinary, false}; }
|
||||
};
|
||||
|
||||
// logistic loss, but predict un-transformed margin
|
||||
@@ -168,6 +177,8 @@ struct LogisticRaw : public LogisticRegression {
|
||||
static const char* DefaultEvalMetric() { return "logloss"; }
|
||||
|
||||
static const char* Name() { return "binary:logitraw"; }
|
||||
|
||||
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||
};
|
||||
|
||||
} // namespace obj
|
||||
|
||||
@@ -52,6 +52,10 @@ class RegLossObj : public ObjFunction {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
|
||||
struct ObjInfo Task() const override {
|
||||
return Loss::Info();
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info, int,
|
||||
HostDeviceVector<GradientPair>* out_gpair) override {
|
||||
@@ -207,6 +211,10 @@ class PoissonRegression : public ObjFunction {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
}
|
||||
|
||||
struct ObjInfo Task() const override {
|
||||
return {ObjInfo::kRegression, false};
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info, int,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
@@ -298,6 +306,10 @@ class CoxRegression : public ObjFunction {
|
||||
void Configure(
|
||||
const std::vector<std::pair<std::string, std::string> >&) override {}
|
||||
|
||||
struct ObjInfo Task() const override {
|
||||
return {ObjInfo::kRegression, false};
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info, int,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
@@ -395,6 +407,10 @@ class GammaRegression : public ObjFunction {
|
||||
void Configure(
|
||||
const std::vector<std::pair<std::string, std::string> >&) override {}
|
||||
|
||||
struct ObjInfo Task() const override {
|
||||
return {ObjInfo::kRegression, false};
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info, int,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
@@ -491,6 +507,10 @@ class TweedieRegression : public ObjFunction {
|
||||
metric_ = os.str();
|
||||
}
|
||||
|
||||
struct ObjInfo Task() const override {
|
||||
return {ObjInfo::kRegression, false};
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info, int,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
|
||||
Reference in New Issue
Block a user