Initial support for multi-target tree. (#8616)
* Implement multi-target for hist. - Add new hist tree builder. - Move data fetchers for tests. - Dispatch function calls in gbm base on the tree type.
This commit is contained in:
@@ -412,7 +412,7 @@ std::pair<Json, Json> TestModelSlice(std::string booster) {
|
||||
j++;
|
||||
}
|
||||
|
||||
// CHECK sliced model doesn't have dependency on old one
|
||||
// CHECK sliced model doesn't have dependency on the old one
|
||||
learner.reset();
|
||||
CHECK_EQ(sliced->GetNumFeature(), kCols);
|
||||
|
||||
|
||||
@@ -473,7 +473,7 @@ inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint
|
||||
int32_t device = Context::kCpuId) {
|
||||
size_t shape[1]{1};
|
||||
LearnerModelParam mparam(n_features, linalg::Tensor<float, 1>{{base_score}, shape, device},
|
||||
n_groups, 1, MultiStrategy::kComposite);
|
||||
n_groups, 1, MultiStrategy::kOneOutputPerTree);
|
||||
return mparam;
|
||||
}
|
||||
|
||||
|
||||
@@ -428,7 +428,7 @@ void TestVectorLeafPrediction(Context const *ctx) {
|
||||
|
||||
LearnerModelParam mparam{static_cast<bst_feature_t>(kCols),
|
||||
linalg::Vector<float>{{0.5}, {1}, Context::kCpuId}, 1, 3,
|
||||
MultiStrategy::kMonolithic};
|
||||
MultiStrategy::kMultiOutputTree};
|
||||
|
||||
std::vector<std::unique_ptr<RegTree>> trees;
|
||||
trees.emplace_back(new RegTree{mparam.LeafLength(), mparam.num_feature});
|
||||
|
||||
@@ -124,11 +124,11 @@ TEST(MultiStrategy, Configure) {
|
||||
auto p_fmat = RandomDataGenerator{12ul, 3ul, 0.0}.GenerateDMatrix();
|
||||
p_fmat->Info().labels.Reshape(p_fmat->Info().num_row_, 2);
|
||||
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
|
||||
learner->SetParams(Args{{"multi_strategy", "monolithic"}, {"num_target", "2"}});
|
||||
learner->SetParams(Args{{"multi_strategy", "multi_output_tree"}, {"num_target", "2"}});
|
||||
learner->Configure();
|
||||
ASSERT_EQ(learner->Groups(), 2);
|
||||
|
||||
learner->SetParams(Args{{"multi_strategy", "monolithic"}, {"num_target", "0"}});
|
||||
learner->SetParams(Args{{"multi_strategy", "multi_output_tree"}, {"num_target", "0"}});
|
||||
ASSERT_THROW({ learner->Configure(); }, dmlc::Error);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user