diff --git a/include/xgboost/data.h b/include/xgboost/data.h index c232819f9..04b489d8b 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -559,8 +559,7 @@ class DMatrix { * * \param uri The URI of input. * \param silent Whether print information during loading. - * \param data_split_mode In distributed mode, split the input according this mode; otherwise, - * it's just an indicator on how the input was split beforehand. + * \param data_split_mode Indicate how the data was split beforehand. * \return The created DMatrix. */ static DMatrix* Load(const std::string& uri, bool silent = true, diff --git a/src/data/data.cc b/src/data/data.cc index 4a2bef6be..3c190a90b 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -729,7 +729,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col } void MetaInfo::SynchronizeNumberOfColumns() { - if (IsVerticalFederated()) { + if (IsColumnSplit()) { collective::Allreduce(&num_col_, 1); } else { collective::Allreduce(&num_col_, 1); @@ -850,14 +850,6 @@ DMatrix* TryLoadBinary(std::string fname, bool silent) { } // namespace DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode) { - auto need_split = false; - if (collective::IsFederated()) { - LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers"; - } else if (collective::IsDistributed()) { - LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers"; - need_split = true; - } - std::string fname, cache_file; auto dlm_pos = uri.find('#'); if (dlm_pos != std::string::npos) { @@ -865,24 +857,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s fname = uri.substr(0, dlm_pos); CHECK_EQ(cache_file.find('#'), std::string::npos) << "Only one `#` is allowed in file path for cache file specification."; - if (need_split && data_split_mode == DataSplitMode::kRow) { - std::ostringstream os; - std::vector cache_shards = common::Split(cache_file, ':'); - for (size_t i = 0; i < cache_shards.size(); ++i) { - size_t pos = cache_shards[i].rfind('.'); - if (pos == std::string::npos) { - os << cache_shards[i] << ".r" << collective::GetRank() << "-" - << collective::GetWorldSize(); - } else { - os << cache_shards[i].substr(0, pos) << ".r" << collective::GetRank() << "-" - << collective::GetWorldSize() << cache_shards[i].substr(pos, cache_shards[i].length()); - } - if (i + 1 != cache_shards.size()) { - os << ':'; - } - } - cache_file = os.str(); - } } else { fname = uri; } @@ -894,19 +868,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s } int partid = 0, npart = 1; - if (need_split && data_split_mode == DataSplitMode::kRow) { - partid = collective::GetRank(); - npart = collective::GetWorldSize(); - } else { - // test option to load in part - npart = 1; - } - - if (npart != 1) { - LOG(CONSOLE) << "Load part of data " << partid << " of " << npart << " parts"; - } - - DMatrix* dmat{nullptr}; + DMatrix* dmat{}; if (cache_file.empty()) { fname = data::ValidateFileFormat(fname); @@ -916,6 +878,8 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s dmat = DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), Context{}.Threads(), cache_file, data_split_mode); } else { + CHECK(data_split_mode != DataSplitMode::kCol) + << "Column-wise data split is not supported for external memory."; data::FileIterator iter{fname, static_cast(partid), static_cast(npart)}; dmat = new data::SparsePageDMatrix{&iter, iter.Proxy(), @@ -926,17 +890,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s cache_file}; } - if (need_split && data_split_mode == DataSplitMode::kCol) { - if (!cache_file.empty()) { - LOG(FATAL) << "Column-wise data split is not support for external memory."; - } - LOG(CONSOLE) << "Splitting data by column"; - auto* sliced = dmat->SliceCol(npart, partid); - delete dmat; - return sliced; - } else { - return dmat; - } + return dmat; } template buffer(collective::GetWorldSize()); buffer[collective::GetRank()] = info_.num_col_; collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t)); - auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0); + auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0ul); if (offset == 0) { return; } diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index d6164894a..5b5bb2bfb 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -64,9 +64,10 @@ class SimpleDMatrix : public DMatrix { /** * \brief Reindex the features based on a global view. * - * In some cases (e.g. vertical federated learning), features are loaded locally with indices - * starting from 0. However, all the algorithms assume the features are globally indexed, so we - * reindex the features based on the offset needed to obtain the global view. + * In some cases (e.g. column-wise data split and vertical federated learning), features are + * loaded locally with indices starting from 0. However, all the algorithms assume the features + * are globally indexed, so we reindex the features based on the offset needed to obtain the + * global view. */ void ReindexFeatures(Context const* ctx); diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index e4d5f2672..fa4165796 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -428,3 +428,21 @@ TEST(SimpleDMatrix, Threads) { DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), 0, "")}; ASSERT_EQ(p_fmat->Ctx()->Threads(), AllThreadsForTest()); } + +namespace { +void VerifyColumnSplit() { + size_t constexpr kRows {16}; + size_t constexpr kCols {8}; + auto dmat = + RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(false, false, 1, DataSplitMode::kCol); + + ASSERT_EQ(dmat->Info().num_col_, kCols * collective::GetWorldSize()); + ASSERT_EQ(dmat->Info().num_row_, kRows); + ASSERT_EQ(dmat->Info().data_split_mode, DataSplitMode::kCol); +} +} // anonymous namespace + +TEST(SimpleDMatrix, ColumnSplit) { + auto constexpr kWorldSize{3}; + RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit); +} diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 604c4d30a..97db9dbd8 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -378,9 +378,8 @@ void RandomDataGenerator::GenerateCSR( CHECK_EQ(columns->Size(), value->Size()); } -[[nodiscard]] std::shared_ptr RandomDataGenerator::GenerateDMatrix(bool with_label, - bool float_label, - size_t classes) const { +[[nodiscard]] std::shared_ptr RandomDataGenerator::GenerateDMatrix( + bool with_label, bool float_label, size_t classes, DataSplitMode data_split_mode) const { HostDeviceVector data; HostDeviceVector rptrs; HostDeviceVector columns; @@ -388,7 +387,7 @@ void RandomDataGenerator::GenerateCSR( data::CSRAdapter adapter(rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), rows_, data.Size(), cols_); std::shared_ptr out{ - DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), 1)}; + DMatrix::Create(&adapter, std::numeric_limits::quiet_NaN(), 1, "", data_split_mode)}; if (with_label) { RandomDataGenerator gen{rows_, n_targets_, 0.0f}; diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index a26669b7d..82a55450e 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -310,9 +310,9 @@ class RandomDataGenerator { void GenerateCSR(HostDeviceVector* value, HostDeviceVector* row_ptr, HostDeviceVector* columns) const; - [[nodiscard]] std::shared_ptr GenerateDMatrix(bool with_label = false, - bool float_label = true, - size_t classes = 1) const; + [[nodiscard]] std::shared_ptr GenerateDMatrix( + bool with_label = false, bool float_label = true, size_t classes = 1, + DataSplitMode data_split_mode = DataSplitMode::kRow) const; [[nodiscard]] std::shared_ptr GenerateSparsePageDMatrix(std::string prefix, bool with_label) const;