Fix tweedie metric string. (#4543)
This commit is contained in:
parent
59ae42a179
commit
da21ac0cc2
@ -437,6 +437,9 @@ class TweedieRegression : public ObjFunction {
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
std::ostringstream os;
|
||||
os << "tweedie-nloglik@" << param_.tweedie_variance_power;
|
||||
metric_ = os.str();
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
@ -499,13 +502,11 @@ class TweedieRegression : public ObjFunction {
|
||||
}
|
||||
|
||||
const char* DefaultEvalMetric() const override {
|
||||
std::ostringstream os;
|
||||
os << "tweedie-nloglik@" << param_.tweedie_variance_power;
|
||||
std::string metric = os.str();
|
||||
return metric.c_str();
|
||||
return metric_.c_str();
|
||||
}
|
||||
|
||||
private:
|
||||
std::string metric_;
|
||||
TweedieRegressionParam param_;
|
||||
HostDeviceVector<int> label_correct_;
|
||||
};
|
||||
|
||||
@ -211,7 +211,7 @@ TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) {
|
||||
{}, // Empty weight.
|
||||
{ 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f},
|
||||
{0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f});
|
||||
|
||||
ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"tweedie-nloglik@1.1"});
|
||||
delete obj;
|
||||
}
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
#include <vector>
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
std::vector<std::pair<std::string, std::string>> args {{"verbosity", "3"}};
|
||||
std::vector<std::pair<std::string, std::string>> args {{"verbosity", "2"}};
|
||||
xgboost::ConsoleLogger::Configure(args.begin(), args.end());
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
testing::FLAGS_gtest_death_test_style = "threadsafe";
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user