From 12e34f32e2b3022003853ae7a9f427854b91d0a7 Mon Sep 17 00:00:00 2001 From: pdesahb <38980421+pdesahb@users.noreply.github.com> Date: Thu, 28 Jun 2018 17:43:05 +0200 Subject: [PATCH] Fix tweedie handling of base_score (#3295) * fix tweedie margin calculations * add entry to contributors --- CONTRIBUTORS.md | 1 + src/objective/regression_obj.cc | 5 +++++ tests/cpp/objective/test_regression_obj.cc | 6 +++--- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 433c596bd..6f75b56ce 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -75,3 +75,4 @@ List of Contributors * [Andrew Hannigan](https://github.com/andrewhannigan) * [Andy Adinets](https://github.com/canonizer) * [Henry Gouk](https://github.com/henrygouk) +* [Pierre de Sahb](https://github.com/pdesahb) diff --git a/src/objective/regression_obj.cc b/src/objective/regression_obj.cc index 98d7b9147..6b793d59c 100644 --- a/src/objective/regression_obj.cc +++ b/src/objective/regression_obj.cc @@ -394,6 +394,11 @@ class TweedieRegression : public ObjFunction { preds[j] = std::exp(preds[j]); } } + + bst_float ProbToMargin(bst_float base_score) const override { + return std::log(base_score); + } + const char* DefaultEvalMetric() const override { std::ostringstream os; os << "tweedie-nloglik@" << param_.tweedie_variance_power; diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 41bcaadee..27820f9df 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -163,9 +163,9 @@ TEST(Objective, TweedieRegressionBasic) { << "Expected error when label < 0 for TweedieRegression"; // test ProbToMargin - EXPECT_NEAR(obj->ProbToMargin(0.1f), 0.10f, 0.01f); - EXPECT_NEAR(obj->ProbToMargin(0.5f), 0.5f, 0.01f); - EXPECT_NEAR(obj->ProbToMargin(0.9f), 0.89f, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.30f, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.5f), -0.69f, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.9f), -0.10f, 0.01f); // test PredTransform xgboost::HostDeviceVector io_preds = {0, 0.1f, 0.5f, 0.9f, 1};