Support vertical federated learning (#8932)
This commit is contained in:
@@ -703,6 +703,14 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::SynchronizeNumberOfColumns() {
|
||||
if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
|
||||
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
|
||||
} else {
|
||||
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::Validate(std::int32_t device) const {
|
||||
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
|
||||
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
|
||||
@@ -870,7 +878,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
|
||||
data::FileAdapter adapter(parser.get());
|
||||
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
|
||||
cache_file);
|
||||
cache_file, data_split_mode);
|
||||
} else {
|
||||
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
|
||||
file_format};
|
||||
@@ -906,11 +914,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
LOG(FATAL) << "Encountered parser error:\n" << e.what();
|
||||
}
|
||||
|
||||
/* sync up number of features after matrix loaded.
|
||||
* partitioned data will fail the train/val validation check
|
||||
* since partitioned data not knowing the real number of features. */
|
||||
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
|
||||
|
||||
if (need_split && data_split_mode == DataSplitMode::kCol) {
|
||||
if (!cache_file.empty()) {
|
||||
LOG(FATAL) << "Column-wise data split is not support for external memory.";
|
||||
@@ -920,7 +923,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
delete dmat;
|
||||
return sliced;
|
||||
} else {
|
||||
dmat->Info().data_split_mode = data_split_mode;
|
||||
return dmat;
|
||||
}
|
||||
}
|
||||
@@ -957,39 +959,49 @@ template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
|
||||
XGDMatrixCallbackNext *next, float missing, int32_t n_threads, std::string);
|
||||
|
||||
template <typename AdapterT>
|
||||
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&) {
|
||||
return new data::SimpleDMatrix(adapter, missing, nthread);
|
||||
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
|
||||
DataSplitMode data_split_mode) {
|
||||
return new data::SimpleDMatrix(adapter, missing, nthread, data_split_mode);
|
||||
}
|
||||
|
||||
template DMatrix* DMatrix::Create<data::DenseAdapter>(data::DenseAdapter* adapter, float missing,
|
||||
std::int32_t nthread,
|
||||
const std::string& cache_prefix);
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::ArrayAdapter>(data::ArrayAdapter* adapter, float missing,
|
||||
std::int32_t nthread,
|
||||
const std::string& cache_prefix);
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::CSRAdapter>(data::CSRAdapter* adapter, float missing,
|
||||
std::int32_t nthread,
|
||||
const std::string& cache_prefix);
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::CSCAdapter>(data::CSCAdapter* adapter, float missing,
|
||||
std::int32_t nthread,
|
||||
const std::string& cache_prefix);
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::DataTableAdapter>(data::DataTableAdapter* adapter,
|
||||
float missing, std::int32_t nthread,
|
||||
const std::string& cache_prefix);
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::FileAdapter>(data::FileAdapter* adapter, float missing,
|
||||
std::int32_t nthread,
|
||||
const std::string& cache_prefix);
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::CSRArrayAdapter>(data::CSRArrayAdapter* adapter,
|
||||
float missing, std::int32_t nthread,
|
||||
const std::string& cache_prefix);
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::CSCArrayAdapter>(data::CSCArrayAdapter* adapter,
|
||||
float missing, std::int32_t nthread,
|
||||
const std::string& cache_prefix);
|
||||
const std::string& cache_prefix,
|
||||
DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create(
|
||||
data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
|
||||
float missing, int nthread, const std::string& cache_prefix);
|
||||
float missing, int nthread, const std::string& cache_prefix, DataSplitMode data_split_mode);
|
||||
template DMatrix* DMatrix::Create<data::RecordBatchesIterAdapter>(
|
||||
data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&);
|
||||
data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&,
|
||||
DataSplitMode data_split_mode);
|
||||
|
||||
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
|
||||
SparsePage transpose;
|
||||
@@ -1051,6 +1063,13 @@ void SparsePage::SortIndices(int32_t n_threads) {
|
||||
});
|
||||
}
|
||||
|
||||
void SparsePage::Reindex(uint64_t feature_offset, int32_t n_threads) {
|
||||
auto& h_data = this->data.HostVector();
|
||||
common::ParallelFor(h_data.size(), n_threads, [&](auto i) {
|
||||
h_data[i].index += feature_offset;
|
||||
});
|
||||
}
|
||||
|
||||
void SparsePage::SortRows(int32_t n_threads) {
|
||||
auto& h_offset = this->offset.HostVector();
|
||||
auto& h_data = this->data.HostVector();
|
||||
|
||||
Reference in New Issue
Block a user