More tests for column split and vertical federated learning (#8985)
Added some more tests for the learner and fit_stump, for both column-wise distributed learning and vertical federated learning. Also moved the `IsRowSplit` and `IsColumnSplit` methods from the `DMatrix` to the `MetaInfo` since in some places we only have access to the `MetaInfo`. Added a new convenience method `IsVerticalFederatedLearning`. Some refactoring of the testing fixtures.
This commit is contained in:
@@ -1,12 +1,9 @@
|
||||
/*!
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#include <dmlc/parameter.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
|
||||
#include "../../../plugin/federated/federated_server.h"
|
||||
@@ -17,49 +14,40 @@
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
class FederatedDataTest : public BaseFederatedTest {
|
||||
public:
|
||||
void VerifyLoadUri(int rank) {
|
||||
InitCommunicator(rank);
|
||||
class FederatedDataTest : public BaseFederatedTest {};
|
||||
|
||||
size_t constexpr kRows{16};
|
||||
size_t const kCols = 8 + rank;
|
||||
void VerifyLoadUri() {
|
||||
auto const rank = collective::GetRank();
|
||||
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string path = tmpdir.path + "/small" + std::to_string(rank) + ".csv";
|
||||
CreateTestCSV(path, kRows, kCols);
|
||||
size_t constexpr kRows{16};
|
||||
size_t const kCols = 8 + rank;
|
||||
|
||||
std::unique_ptr<DMatrix> dmat;
|
||||
std::string uri = path + "?format=csv";
|
||||
dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string path = tmpdir.path + "/small" + std::to_string(rank) + ".csv";
|
||||
CreateTestCSV(path, kRows, kCols);
|
||||
|
||||
ASSERT_EQ(dmat->Info().num_col_, 8 * kWorldSize + 3);
|
||||
ASSERT_EQ(dmat->Info().num_row_, kRows);
|
||||
std::unique_ptr<DMatrix> dmat;
|
||||
std::string uri = path + "?format=csv";
|
||||
dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kCol));
|
||||
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
auto entries = page.GetView().data;
|
||||
auto index = 0;
|
||||
int offsets[] = {0, 8, 17};
|
||||
int offset = offsets[rank];
|
||||
for (auto row = 0; row < kRows; row++) {
|
||||
for (auto col = 0; col < kCols; col++) {
|
||||
EXPECT_EQ(entries[index].index, col + offset);
|
||||
index++;
|
||||
}
|
||||
ASSERT_EQ(dmat->Info().num_col_, 8 * collective::GetWorldSize() + 3);
|
||||
ASSERT_EQ(dmat->Info().num_row_, kRows);
|
||||
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
auto entries = page.GetView().data;
|
||||
auto index = 0;
|
||||
int offsets[] = {0, 8, 17};
|
||||
int offset = offsets[rank];
|
||||
for (auto row = 0; row < kRows; row++) {
|
||||
for (auto col = 0; col < kCols; col++) {
|
||||
EXPECT_EQ(entries[index].index, col + offset);
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
xgboost::collective::Finalize();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(FederatedDataTest, LoadUri) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back(&FederatedDataTest_LoadUri_Test::VerifyLoadUri, this, rank);
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(FederatedDataTest, LoadUri) {
|
||||
RunWithFederatedCommunicator(kWorldSize, server_address_, &VerifyLoadUri);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user