Calculate base_score based on input labels for mae. (#8107)
Fit an intercept as base score for abs loss.
This commit is contained in:
@@ -210,11 +210,7 @@ void TestCategoricalPrediction(std::string name) {
|
||||
size_t constexpr kCols = 10;
|
||||
PredictionCacheEntry out_predictions;
|
||||
|
||||
LearnerModelParam param;
|
||||
param.num_feature = kCols;
|
||||
param.num_output_group = 1;
|
||||
param.base_score = 0.5;
|
||||
|
||||
LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
|
||||
uint32_t split_ind = 3;
|
||||
bst_cat_t split_cat = 4;
|
||||
float left_weight = 1.3f;
|
||||
@@ -222,7 +218,7 @@ void TestCategoricalPrediction(std::string name) {
|
||||
|
||||
GenericParameter ctx;
|
||||
ctx.UpdateAllowUnknown(Args{});
|
||||
gbm::GBTreeModel model(¶m, &ctx);
|
||||
gbm::GBTreeModel model(&mparam, &ctx);
|
||||
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
|
||||
|
||||
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
|
||||
@@ -237,27 +233,24 @@ void TestCategoricalPrediction(std::string name) {
|
||||
|
||||
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
||||
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
||||
auto score = mparam.BaseScore(Context::kCpuId)(0);
|
||||
ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
|
||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
|
||||
right_weight + param.base_score); // go to right for matching cat
|
||||
right_weight + score); // go to right for matching cat
|
||||
|
||||
row[split_ind] = split_cat + 1;
|
||||
m = GetDMatrixFromData(row, 1, kCols);
|
||||
out_predictions.version = 0;
|
||||
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
||||
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
|
||||
left_weight + param.base_score);
|
||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + score);
|
||||
}
|
||||
|
||||
void TestCategoricalPredictLeaf(StringView name) {
|
||||
size_t constexpr kCols = 10;
|
||||
PredictionCacheEntry out_predictions;
|
||||
|
||||
LearnerModelParam param;
|
||||
param.num_feature = kCols;
|
||||
param.num_output_group = 1;
|
||||
param.base_score = 0.5;
|
||||
LearnerModelParam mparam{MakeMP(kCols, .5, 1)};
|
||||
|
||||
uint32_t split_ind = 3;
|
||||
bst_cat_t split_cat = 4;
|
||||
@@ -267,7 +260,7 @@ void TestCategoricalPredictLeaf(StringView name) {
|
||||
GenericParameter ctx;
|
||||
ctx.UpdateAllowUnknown(Args{});
|
||||
|
||||
gbm::GBTreeModel model(¶m, &ctx);
|
||||
gbm::GBTreeModel model(&mparam, &ctx);
|
||||
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
|
||||
|
||||
ctx.gpu_id = 0;
|
||||
|
||||
Reference in New Issue
Block a user