More tests for column split and vertical federated learning (#8985)

Added some more tests for the learner and fit_stump, for both column-wise distributed learning and vertical federated learning.

Also moved the `IsRowSplit` and `IsColumnSplit` methods from the `DMatrix` to the `MetaInfo` since in some places we only have access to the `MetaInfo`. Added a new convenience method `IsVerticalFederatedLearning`.

Some refactoring of the testing fixtures.
This commit is contained in:
Rong Ou
2023-03-28 01:40:26 -07:00
committed by GitHub
parent 401ce5cf5e
commit ff26cd3212
18 changed files with 212 additions and 94 deletions

View File

@@ -460,10 +460,6 @@ class InitBaseScore : public ::testing::Test {
void SetUp() override { Xy_ = RandomDataGenerator{10, Cols(), 0}.GenerateDMatrix(true); }
static float GetBaseScore(Json const &config) {
return std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
}
public:
void TestUpdateConfig() {
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
@@ -611,4 +607,32 @@ TEST_F(InitBaseScore, InitAfterLoad) { this->TestInitAfterLoad(); }
TEST_F(InitBaseScore, InitWithPredict) { this->TestInitWithPredt(); }
TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); }
void TestColumnSplitBaseScore(std::shared_ptr<DMatrix> Xy_, float expected_base_score) {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::shared_ptr<DMatrix> sliced{Xy_->SliceCol(world_size, rank)};
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
learner->SetParam("tree_method", "approx");
learner->SetParam("objective", "binary:logistic");
learner->UpdateOneIter(0, sliced);
Json config{Object{}};
learner->SaveConfig(&config);
auto base_score = GetBaseScore(config);
ASSERT_EQ(base_score, expected_base_score);
}
TEST_F(InitBaseScore, ColumnSplit) {
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
learner->SetParam("tree_method", "approx");
learner->SetParam("objective", "binary:logistic");
learner->UpdateOneIter(0, Xy_);
Json config{Object{}};
learner->SaveConfig(&config);
auto base_score = GetBaseScore(config);
ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
auto constexpr kWorldSize{3};
RunWithInMemoryCommunicator(kWorldSize, &TestColumnSplitBaseScore, Xy_, base_score);
}
} // namespace xgboost