[breaking] Change DMatrix construction to be distributed (#9623)

* Change column-split DMatrix construction to be distributed

* remove splitting code for row split
This commit is contained in:
Rong Ou
2023-10-10 08:35:57 -07:00
committed by GitHub
parent b14e535e78
commit 0ecb4de963
7 changed files with 36 additions and 65 deletions

View File

@@ -729,7 +729,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
void MetaInfo::SynchronizeNumberOfColumns() {
if (IsVerticalFederated()) {
if (IsColumnSplit()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
@@ -850,14 +850,6 @@ DMatrix* TryLoadBinary(std::string fname, bool silent) {
} // namespace
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) {
auto need_split = false;
if (collective::IsFederated()) {
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
} else if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
need_split = true;
}
std::string fname, cache_file;
auto dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) {
@@ -865,24 +857,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
fname = uri.substr(0, dlm_pos);
CHECK_EQ(cache_file.find('#'), std::string::npos)
<< "Only one `#` is allowed in file path for cache file specification.";
if (need_split && data_split_mode == DataSplitMode::kRow) {
std::ostringstream os;
std::vector<std::string> cache_shards = common::Split(cache_file, ':');
for (size_t i = 0; i < cache_shards.size(); ++i) {
size_t pos = cache_shards[i].rfind('.');
if (pos == std::string::npos) {
os << cache_shards[i] << ".r" << collective::GetRank() << "-"
<< collective::GetWorldSize();
} else {
os << cache_shards[i].substr(0, pos) << ".r" << collective::GetRank() << "-"
<< collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length());
}
if (i + 1 != cache_shards.size()) {
os << ':';
}
}
cache_file = os.str();
}
} else {
fname = uri;
}
@@ -894,19 +868,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
}
int partid = 0, npart = 1;
if (need_split && data_split_mode == DataSplitMode::kRow) {
partid = collective::GetRank();
npart = collective::GetWorldSize();
} else {
// test option to load in part
npart = 1;
}
if (npart != 1) {
LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
}
DMatrix* dmat{nullptr};
DMatrix* dmat{};
if (cache_file.empty()) {
fname = data::ValidateFileFormat(fname);
@@ -916,6 +878,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
cache_file, data_split_mode);
} else {
CHECK(data_split_mode != DataSplitMode::kCol)
<< "Column-wise data split is not supported for external memory.";
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart)};
dmat = new data::SparsePageDMatrix{&iter,
iter.Proxy(),
@@ -926,17 +890,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
cache_file};
}
if (need_split && data_split_mode == DataSplitMode::kCol) {
if (!cache_file.empty()) {
LOG(FATAL) << "Column-wise data split is not support for external memory.";
}
LOG(CONSOLE) << "Splitting data by column";
auto* sliced = dmat->SliceCol(npart, partid);
delete dmat;
return sliced;
} else {
return dmat;
}
return dmat;
}
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,

View File

@@ -75,11 +75,11 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
}
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
if (info_.IsVerticalFederated()) {
if (info_.IsColumnSplit()) {
std::vector<uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_;
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0ul);
if (offset == 0) {
return;
}

View File

@@ -64,9 +64,10 @@ class SimpleDMatrix : public DMatrix {
/**
* \brief Reindex the features based on a global view.
*
* In some cases (e.g. vertical federated learning), features are loaded locally with indices
* starting from 0. However, all the algorithms assume the features are globally indexed, so we
* reindex the features based on the offset needed to obtain the global view.
* In some cases (e.g. column-wise data split and vertical federated learning), features are
* loaded locally with indices starting from 0. However, all the algorithms assume the features
* are globally indexed, so we reindex the features based on the offset needed to obtain the
* global view.
*/
void ReindexFeatures(Context const* ctx);