From da21ac0cc22dc283fc4f551da7ebdb66d61fad20 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sun, 9 Jun 2019 09:52:29 +0800 Subject: [PATCH] Fix tweedie metric string. (#4543) --- src/objective/regression_obj.cu | 9 +++++---- tests/cpp/objective/test_regression_obj.cc | 2 +- tests/cpp/test_main.cc | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 086603840..545b95564 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -437,6 +437,9 @@ class TweedieRegression : public ObjFunction { // declare functions void Configure(const std::vector >& args) override { param_.InitAllowUnknown(args); + std::ostringstream os; + os << "tweedie-nloglik@" << param_.tweedie_variance_power; + metric_ = os.str(); } void GetGradient(const HostDeviceVector& preds, @@ -499,13 +502,11 @@ class TweedieRegression : public ObjFunction { } const char* DefaultEvalMetric() const override { - std::ostringstream os; - os << "tweedie-nloglik@" << param_.tweedie_variance_power; - std::string metric = os.str(); - return metric.c_str(); + return metric_.c_str(); } private: + std::string metric_; TweedieRegressionParam param_; HostDeviceVector label_correct_; }; diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 1f4cec2dd..b17a6e63c 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -211,7 +211,7 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) { {}, // Empty weight. { 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f}, {0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f}); - + ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"tweedie-nloglik@1.1"}); delete obj; } diff --git a/tests/cpp/test_main.cc b/tests/cpp/test_main.cc index 980fa0cda..6e9a1fa2a 100644 --- a/tests/cpp/test_main.cc +++ b/tests/cpp/test_main.cc @@ -5,7 +5,7 @@ #include int main(int argc, char ** argv) { - std::vector> args {{"verbosity", "3"}}; + std::vector> args {{"verbosity", "2"}}; xgboost::ConsoleLogger::Configure(args.begin(), args.end()); testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";