Support vertical federated learning (#8932)

This commit is contained in:
Rong Ou
2023-03-21 23:25:26 -07:00
committed by GitHub
parent 8dc1e4b3ea
commit b240f055d3
23 changed files with 371 additions and 249 deletions

View File

@@ -703,6 +703,14 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
}
// Reconcile `num_col_` across all workers with a single-element allreduce.
//
// When the communicator is federated AND the data is split by column, the
// per-worker column counts are summed (presumably because each worker then
// holds a disjoint slice of the features -- confirm against the callers).
// In every other case the maximum column count over all workers is taken.
void MetaInfo::SynchronizeNumberOfColumns() {
  // `IsFederated()` is evaluated first; the split mode is only inspected when
  // it returns true, matching the short-circuit order callers already rely on.
  bool const federated_column_split =
      collective::IsFederated() && data_split_mode == DataSplitMode::kCol;

  if (!federated_column_split) {
    collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
    return;
  }

  collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
}
void MetaInfo::Validate(std::int32_t device) const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
@@ -870,7 +878,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
data::FileAdapter adapter(parser.get());
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
cache_file);
cache_file, data_split_mode);
} else {
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
file_format};
@@ -906,11 +914,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
LOG(FATAL) << "Encountered parser error:\n" << e.what();
}
/* Sync up the number of features after the matrix is loaded.
 * Partitioned data would otherwise fail the train/val validation check,
 * because a single partition does not know the real number of features. */
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
if (need_split && data_split_mode == DataSplitMode::kCol) {
if (!cache_file.empty()) {
LOG(FATAL) << "Column-wise data split is not support for external memory.";
@@ -920,7 +923,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
delete dmat;
return sliced;
} else {
dmat->Info().data_split_mode = data_split_mode;
return dmat;
}
}
@@ -957,39 +959,49 @@ template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
XGDMatrixCallbackNext *next, float missing, int32_t n_threads, std::string);
template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&) {
return new data::SimpleDMatrix(adapter, missing, nthread);
// Build a DMatrix from an adapter by constructing a data::SimpleDMatrix.
//
// adapter          forwarded to the SimpleDMatrix constructor
// missing          forwarded to the SimpleDMatrix constructor
// nthread          forwarded to the SimpleDMatrix constructor
// (unnamed string) accepted for signature compatibility; not used in this body
// data_split_mode  forwarded to the SimpleDMatrix constructor
//
// Returns a pointer obtained from `new`; NOTE(review): the caller presumably
// takes ownership and is responsible for deleting it -- confirm at call sites.
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
DataSplitMode data_split_mode) {
return new data::SimpleDMatrix(adapter, missing, nthread, data_split_mode);
}
// Explicit instantiations of DMatrix::Create, one per adapter type named in
// the template argument / first parameter below (Dense, Array, CSR, CSC,
// DataTable, File, CSRArray, CSCArray, IteratorAdapter, RecordBatchesIter).
// NOTE(review): this listing appears to interleave the pre-change signature
// lines (ending in `cache_prefix);`) with their replacements that add the
// trailing `DataSplitMode data_split_mode` parameter -- confirm against the
// actual file before relying on any single line here.
template DMatrix* DMatrix::Create<data::DenseAdapter>(data::DenseAdapter* adapter, float missing,
std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::ArrayAdapter>(data::ArrayAdapter* adapter, float missing,
std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::CSRAdapter>(data::CSRAdapter* adapter, float missing,
std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::CSCAdapter>(data::CSCAdapter* adapter, float missing,
std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::DataTableAdapter>(data::DataTableAdapter* adapter,
float missing, std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::FileAdapter>(data::FileAdapter* adapter, float missing,
std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::CSRArrayAdapter>(data::CSRArrayAdapter* adapter,
float missing, std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::CSCArrayAdapter>(data::CSCArrayAdapter* adapter,
float missing, std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create(
data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
float missing, int nthread, const std::string& cache_prefix);
float missing, int nthread, const std::string& cache_prefix, DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::RecordBatchesIterAdapter>(
data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&);
data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&,
DataSplitMode data_split_mode);
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
SparsePage transpose;
@@ -1051,6 +1063,13 @@ void SparsePage::SortIndices(int32_t n_threads) {
});
}
void SparsePage::Reindex(uint64_t feature_offset, int32_t n_threads) {
auto& h_data = this->data.HostVector();
common::ParallelFor(h_data.size(), n_threads, [&](auto i) {
h_data[i].index += feature_offset;
});
}
void SparsePage::SortRows(int32_t n_threads) {
auto& h_offset = this->offset.HostVector();
auto& h_data = this->data.HostVector();