[breaking] Change DMatrix construction to be distributed (#9623)
* Change column-split DMatrix construction to be distributed
* Remove splitting code for row split
commit 0ecb4de963
parent b14e535e78
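From the application side, the breaking change is that a column-split DMatrix is now assembled in a distributed fashion while it is constructed: each worker loads only its own feature columns, and construction takes care of the global column count and feature indices, instead of loading the whole matrix and slicing it afterwards as the removed code below did. A minimal caller-side sketch, assuming the class DMatrix header edited in the first hunk (presumably include/xgboost/data.h), an already-initialized collective communicator on each worker, and an illustrative file URI:

#include <xgboost/data.h>

#include <memory>

int main() {
  // Runs on every worker of a distributed job; communicator setup is not shown here.
  // Each worker points Load() at the shard holding its own columns.  With this commit
  // the column-split DMatrix is built during construction rather than sliced post-load.
  std::unique_ptr<xgboost::DMatrix> dmat{
      xgboost::DMatrix::Load("worker_columns.libsvm?format=libsvm", /*silent=*/true,
                             xgboost::DataSplitMode::kCol)};
  return 0;
}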
@@ -559,8 +559,7 @@ class DMatrix {
    *
    * \param uri The URI of input.
    * \param silent Whether print information during loading.
-   * \param data_split_mode In distributed mode, split the input according this mode; otherwise,
-   *        it's just an indicator on how the input was split beforehand.
+   * \param data_split_mode Indicate how the data was split beforehand.
    * \return The created DMatrix.
    */
  static DMatrix* Load(const std::string& uri, bool silent = true,
@@ -729,7 +729,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
 }
 
 void MetaInfo::SynchronizeNumberOfColumns() {
-  if (IsVerticalFederated()) {
+  if (IsColumnSplit()) {
     collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
   } else {
     collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
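The reduction choice above is easiest to see with numbers: under a column split every worker holds a different subset of features, so the global width is the sum of the per-worker counts; otherwise the maximum is taken. A self-contained simulation, with a plain vector standing in for the values the collective::Allreduce calls would combine and made-up counts:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<std::uint64_t> num_col_per_worker{3, 5, 4};

  // Column split (kSum): workers own disjoint features, so the global width is 3 + 5 + 4.
  auto summed = std::accumulate(num_col_per_worker.cbegin(), num_col_per_worker.cend(),
                                std::uint64_t{0});

  // Other modes (kMax): every worker should already see the full width, so the max is kept.
  auto maxed = *std::max_element(num_col_per_worker.cbegin(), num_col_per_worker.cend());

  std::cout << "kSum -> " << summed << ", kMax -> " << maxed << "\n";  // kSum -> 12, kMax -> 5
  return 0;
}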
@@ -850,14 +850,6 @@ DMatrix* TryLoadBinary(std::string fname, bool silent) {
 }  // namespace
 
 DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) {
-  auto need_split = false;
-  if (collective::IsFederated()) {
-    LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
-  } else if (collective::IsDistributed()) {
-    LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
-    need_split = true;
-  }
-
   std::string fname, cache_file;
   auto dlm_pos = uri.find('#');
   if (dlm_pos != std::string::npos) {
@@ -865,24 +857,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
     fname = uri.substr(0, dlm_pos);
     CHECK_EQ(cache_file.find('#'), std::string::npos)
         << "Only one `#` is allowed in file path for cache file specification.";
-    if (need_split && data_split_mode == DataSplitMode::kRow) {
-      std::ostringstream os;
-      std::vector<std::string> cache_shards = common::Split(cache_file, ':');
-      for (size_t i = 0; i < cache_shards.size(); ++i) {
-        size_t pos = cache_shards[i].rfind('.');
-        if (pos == std::string::npos) {
-          os << cache_shards[i] << ".r" << collective::GetRank() << "-"
-             << collective::GetWorldSize();
-        } else {
-          os << cache_shards[i].substr(0, pos) << ".r" << collective::GetRank() << "-"
-             << collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length());
-        }
-        if (i + 1 != cache_shards.size()) {
-          os << ':';
-        }
-      }
-      cache_file = os.str();
-    }
   } else {
     fname = uri;
   }
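For reference, the block deleted here is what used to rename external-memory cache shards per worker when rows were split automatically: a shard named cache became cache.r0-2 on rank 0 of a 2-worker job, and cache.bin became cache.r1-2.bin on rank 1. A standalone re-creation of that renaming, with plain integers standing in for collective::GetRank() and collective::GetWorldSize():

#include <iostream>
#include <sstream>
#include <string>

std::string RenameCacheShard(std::string const& shard, int rank, int world_size) {
  std::ostringstream os;
  auto pos = shard.rfind('.');
  if (pos == std::string::npos) {
    os << shard << ".r" << rank << "-" << world_size;
  } else {
    // Insert the rank/world-size tag before the file extension.
    os << shard.substr(0, pos) << ".r" << rank << "-" << world_size << shard.substr(pos);
  }
  return os.str();
}

int main() {
  std::cout << RenameCacheShard("cache", 0, 2) << "\n";      // cache.r0-2
  std::cout << RenameCacheShard("cache.bin", 1, 2) << "\n";  // cache.r1-2.bin
  return 0;
}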
@@ -894,19 +868,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
   }
 
   int partid = 0, npart = 1;
-  if (need_split && data_split_mode == DataSplitMode::kRow) {
-    partid = collective::GetRank();
-    npart = collective::GetWorldSize();
-  } else {
-    // test option to load in part
-    npart = 1;
-  }
-
-  if (npart != 1) {
-    LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts";
-  }
-
-  DMatrix* dmat{nullptr};
+  DMatrix* dmat{};
 
   if (cache_file.empty()) {
     fname = data::ValidateFileFormat(fname);
@@ -916,6 +878,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
     dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
                            cache_file, data_split_mode);
   } else {
+    CHECK(data_split_mode != DataSplitMode::kCol)
+        << "Column-wise data split is not supported for external memory.";
     data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart)};
     dmat = new data::SparsePageDMatrix{&iter,
                                        iter.Proxy(),
@@ -926,17 +890,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
                                        cache_file};
   }
 
-  if (need_split && data_split_mode == DataSplitMode::kCol) {
-    if (!cache_file.empty()) {
-      LOG(FATAL) << "Column-wise data split is not support for external memory.";
-    }
-    LOG(CONSOLE) << "Splitting data by column";
-    auto* sliced = dmat->SliceCol(npart, partid);
-    delete dmat;
-    return sliced;
-  } else {
-    return dmat;
-  }
+  return dmat;
 }
 
 template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
@@ -75,11 +75,11 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
 }
 
 void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
-  if (info_.IsVerticalFederated()) {
+  if (info_.IsColumnSplit()) {
     std::vector<uint64_t> buffer(collective::GetWorldSize());
     buffer[collective::GetRank()] = info_.num_col_;
     collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
-    auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
+    auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0ul);
     if (offset == 0) {
       return;
     }
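The offset computed above is a prefix sum over the gathered per-worker column counts: worker r shifts its local feature indices by the number of columns owned by workers 0..r-1. The 0 -> 0ul change keeps that sum in an unsigned type, since std::accumulate accumulates into the type of its initial value and a plain int 0 would narrow the uint64_t counts. A self-contained illustration with made-up counts:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Stands in for the Allgather result: one column count per worker.
  std::vector<std::uint64_t> cols_per_worker{8, 8, 8};
  int rank = 2;

  // Prefix sum up to (but not including) this worker's own entry; 0ul keeps it unsigned.
  auto offset = std::accumulate(cols_per_worker.cbegin(), cols_per_worker.cbegin() + rank, 0ul);
  std::cout << "worker " << rank << " shifts its local feature ids by " << offset << "\n";  // 16
  return 0;
}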
@@ -64,9 +64,10 @@ class SimpleDMatrix : public DMatrix {
   /**
    * \brief Reindex the features based on a global view.
    *
-   * In some cases (e.g. vertical federated learning), features are loaded locally with indices
-   * starting from 0. However, all the algorithms assume the features are globally indexed, so we
-   * reindex the features based on the offset needed to obtain the global view.
+   * In some cases (e.g. column-wise data split and vertical federated learning), features are
+   * loaded locally with indices starting from 0. However, all the algorithms assume the features
+   * are globally indexed, so we reindex the features based on the offset needed to obtain the
+   * global view.
    */
  void ReindexFeatures(Context const* ctx);
 
@@ -428,3 +428,21 @@ TEST(SimpleDMatrix, Threads) {
       DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 0, "")};
   ASSERT_EQ(p_fmat->Ctx()->Threads(), AllThreadsForTest());
 }
+
+namespace {
+void VerifyColumnSplit() {
+  size_t constexpr kRows {16};
+  size_t constexpr kCols {8};
+  auto dmat =
+      RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(false, false, 1, DataSplitMode::kCol);
+
+  ASSERT_EQ(dmat->Info().num_col_, kCols * collective::GetWorldSize());
+  ASSERT_EQ(dmat->Info().num_row_, kRows);
+  ASSERT_EQ(dmat->Info().data_split_mode, DataSplitMode::kCol);
+}
+}  // anonymous namespace
+
+TEST(SimpleDMatrix, ColumnSplit) {
+  auto constexpr kWorldSize{3};
+  RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit);
+}
@@ -378,9 +378,8 @@ void RandomDataGenerator::GenerateCSR(
   CHECK_EQ(columns->Size(), value->Size());
 }
 
-[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label,
-                                                                             bool float_label,
-                                                                             size_t classes) const {
+[[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(
+    bool with_label, bool float_label, size_t classes, DataSplitMode data_split_mode) const {
   HostDeviceVector<float> data;
   HostDeviceVector<bst_row_t> rptrs;
   HostDeviceVector<bst_feature_t> columns;
@@ -388,7 +387,7 @@ void RandomDataGenerator::GenerateCSR(
   data::CSRAdapter adapter(rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), rows_,
                            data.Size(), cols_);
   std::shared_ptr<DMatrix> out{
-      DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
+      DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1, "", data_split_mode)};
 
   if (with_label) {
     RandomDataGenerator gen{rows_, n_targets_, 0.0f};
@@ -310,9 +310,9 @@ class RandomDataGenerator {
   void GenerateCSR(HostDeviceVector<float>* value, HostDeviceVector<bst_row_t>* row_ptr,
                    HostDeviceVector<bst_feature_t>* columns) const;
 
-  [[nodiscard]] std::shared_ptr<DMatrix> GenerateDMatrix(bool with_label = false,
-                                                         bool float_label = true,
-                                                         size_t classes = 1) const;
+  [[nodiscard]] std::shared_ptr<DMatrix> GenerateDMatrix(
+      bool with_label = false, bool float_label = true, size_t classes = 1,
+      DataSplitMode data_split_mode = DataSplitMode::kRow) const;
 
   [[nodiscard]] std::shared_ptr<DMatrix> GenerateSparsePageDMatrix(std::string prefix,
                                                                    bool with_label) const;
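A small usage note on the widened test helper: because the new data_split_mode parameter defaults to DataSplitMode::kRow, existing call sites compile unchanged and only column-split tests pass the extra argument. A hedged sketch mirroring the test above (the generator arguments are rows, columns, sparsity):

// Existing tests keep their row-split behaviour without touching the call.
auto row_split = RandomDataGenerator{16, 8, 0}.GenerateDMatrix(true);

// New column-split tests opt in explicitly, as TEST(SimpleDMatrix, ColumnSplit) does.
auto col_split =
    RandomDataGenerator{16, 8, 0}.GenerateDMatrix(false, false, 1, DataSplitMode::kCol);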