Pass infomation about objective to tree methods. (#7385)
* Define the `ObjInfo` and pass it down to every tree updater.
This commit is contained in:
@@ -7,6 +7,8 @@
|
||||
#include <dmlc/omp.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <algorithm>
|
||||
|
||||
#include "xgboost/task.h"
|
||||
#include "../common/math.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -36,6 +38,7 @@ struct LinearSquareLoss {
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
|
||||
static const char* Name() { return "reg:squarederror"; }
|
||||
static ObjInfo Info() { return {ObjInfo::kRegression, true}; }
|
||||
};
|
||||
|
||||
struct SquaredLogError {
|
||||
@@ -61,6 +64,8 @@ struct SquaredLogError {
|
||||
static const char* DefaultEvalMetric() { return "rmsle"; }
|
||||
|
||||
static const char* Name() { return "reg:squaredlogerror"; }
|
||||
|
||||
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||
};
|
||||
|
||||
// logistic loss for probability regression task
|
||||
@@ -96,6 +101,8 @@ struct LogisticRegression {
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
|
||||
static const char* Name() { return "reg:logistic"; }
|
||||
|
||||
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||
};
|
||||
|
||||
struct PseudoHuberError {
|
||||
@@ -127,12 +134,14 @@ struct PseudoHuberError {
|
||||
static const char* Name() {
|
||||
return "reg:pseudohubererror";
|
||||
}
|
||||
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||
};
|
||||
|
||||
// logistic loss for binary classification task
|
||||
struct LogisticClassification : public LogisticRegression {
|
||||
static const char* DefaultEvalMetric() { return "logloss"; }
|
||||
static const char* Name() { return "binary:logistic"; }
|
||||
static ObjInfo Info() { return {ObjInfo::kBinary, false}; }
|
||||
};
|
||||
|
||||
// logistic loss, but predict un-transformed margin
|
||||
@@ -168,6 +177,8 @@ struct LogisticRaw : public LogisticRegression {
|
||||
static const char* DefaultEvalMetric() { return "logloss"; }
|
||||
|
||||
static const char* Name() { return "binary:logitraw"; }
|
||||
|
||||
static ObjInfo Info() { return {ObjInfo::kRegression, false}; }
|
||||
};
|
||||
|
||||
} // namespace obj
|
||||
|
||||
Reference in New Issue
Block a user