Calculate base_score based on input labels for mae. (#8107)

Fit an intercept as base score for abs loss.
This commit is contained in:
Jiaming Yuan
2022-09-20 20:53:54 +08:00
committed by GitHub
parent 4f42aa5f12
commit fffb1fca52
42 changed files with 999 additions and 343 deletions

View File

@@ -21,14 +21,11 @@ TEST(CpuPredictor, Basic) {
size_t constexpr kRows = 5;
size_t constexpr kCols = 5;
LearnerModelParam param;
param.num_feature = kCols;
param.base_score = 0.0;
param.num_output_group = 1;
LearnerModelParam mparam{MakeMP(kCols, .0, 1)};
GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model = CreateTestModel(&param, &ctx);
gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
@@ -104,14 +101,11 @@ TEST(CpuPredictor, ExternalMemory) {
std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
LearnerModelParam param;
param.base_score = 0;
param.num_feature = dmat->Info().num_col_;
param.num_output_group = 1;
LearnerModelParam mparam{MakeMP(dmat->Info().num_col_, .0, 1)};
GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model = CreateTestModel(&param, &ctx);
gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
// Test predict batch
PredictionCacheEntry out_predictions;
@@ -201,16 +195,11 @@ TEST(CpuPredictor, InplacePredict) {
void TestUpdatePredictionCache(bool use_subsampling) {
size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
LearnerModelParam mparam;
mparam.num_feature = kCols;
mparam.num_output_group = kClasses;
mparam.base_score = 0;
GenericParameter gparam;
gparam.Init(Args{});
LearnerModelParam mparam{MakeMP(kCols, .0, kClasses)};
Context ctx;
std::unique_ptr<gbm::GBTree> gbm;
gbm.reset(static_cast<gbm::GBTree*>(GradientBooster::Create("gbtree", &gparam, &mparam)));
gbm.reset(static_cast<gbm::GBTree*>(GradientBooster::Create("gbtree", &ctx, &mparam)));
std::map<std::string, std::string> cfg;
cfg["tree_method"] = "hist";
cfg["predictor"] = "cpu_predictor";