Fix tweedie metric string. (#4543)

This commit is contained in:
Jiaming Yuan 2019-06-09 09:52:29 +08:00 committed by GitHub
parent 59ae42a179
commit da21ac0cc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 6 deletions

View File

@ -437,6 +437,9 @@ class TweedieRegression : public ObjFunction {
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
std::ostringstream os;
os << "tweedie-nloglik@" << param_.tweedie_variance_power;
metric_ = os.str();
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
@ -499,13 +502,11 @@ class TweedieRegression : public ObjFunction {
}
const char* DefaultEvalMetric() const override {
std::ostringstream os;
os << "tweedie-nloglik@" << param_.tweedie_variance_power;
std::string metric = os.str();
return metric.c_str();
return metric_.c_str();
}
private:
std::string metric_;
TweedieRegressionParam param_;
HostDeviceVector<int> label_correct_;
};

View File

@ -211,7 +211,7 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) {
{}, // Empty weight.
{ 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f},
{0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f});
ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"tweedie-nloglik@1.1"});
delete obj;
}

View File

@ -5,7 +5,7 @@
#include <vector>
int main(int argc, char ** argv) {
std::vector<std::pair<std::string, std::string>> args {{"verbosity", "3"}};
std::vector<std::pair<std::string, std::string>> args {{"verbosity", "2"}};
xgboost::ConsoleLogger::Configure(args.begin(), args.end());
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";