diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index b454cf065..871b90092 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -136,3 +136,39 @@ TEST(Objective, GammaRegressionBasic) { EXPECT_NEAR(preds[i], out_preds[i], 0.01); } } + +TEST(Objective, TweedieRegressionGPair) { + xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie"); + std::vector > args; + args.push_back(std::make_pair("tweedie_variance_power", "1.1")); + obj->Configure(args); + CheckObjFunction(obj, + { 0, 0.1, 0.9, 1, 0, 0.1, 0.9, 1}, + { 0, 0, 0, 0, 1, 1, 1, 1}, + { 1, 1, 1, 1, 1, 1, 1, 1}, + { 1, 1.09, 2.24, 2.45, 0, 0.10, 1.33, 1.55}, + {0.89, 0.98, 2.02, 2.21, 1, 1.08, 2.11, 2.30}); +} + +TEST(Objective, TweedieRegressionBasic) { + xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie"); + std::vector > args; + obj->Configure(args); + + // test label validation + EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {-1}, {1}, {0}, {0})) + << "Expected error when label < 0 for TweedieRegression"; + + // test ProbToMargin + EXPECT_NEAR(obj->ProbToMargin(0.1), 0.10, 0.01); + EXPECT_NEAR(obj->ProbToMargin(0.5), 0.5, 0.01); + EXPECT_NEAR(obj->ProbToMargin(0.9), 0.89, 0.01); + + // test PredTransform + std::vector preds = {0, 0.1, 0.5, 0.9, 1}; + std::vector out_preds = {1, 1.10, 1.64, 2.45, 2.71}; + obj->PredTransform(&preds); + for (int i = 0; i < preds.size(); ++i) { + EXPECT_NEAR(preds[i], out_preds[i], 0.01); + } +}