diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 12693dc83..c5ed78dd5 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -40,6 +40,10 @@ enum class DataType : uint8_t { enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 }; +enum class DataSplitMode : int { + kAuto = 0, kCol = 1, kRow = 2, kNone = 3 +}; + /*! * \brief Meta information about dataset, always sit in memory. */ @@ -537,7 +541,7 @@ class DMatrix { * \brief Load DMatrix from URI. * \param uri The URI of input. * \param silent Whether print information during loading. - * \param load_row_split Flag to read in part of rows, divided among the workers in distributed mode. + * \param data_split_mode Mode to read in part of the data, divided among the workers in distributed mode. * \param file_format The format type of the file, used for dmlc::Parser::Create. * By default "auto" will be able to load in both local binary file. * \param page_size Page size for external memory. @@ -545,7 +549,7 @@ class DMatrix { */ static DMatrix* Load(const std::string& uri, bool silent, - bool load_row_split, + DataSplitMode data_split_mode, const std::string& file_format = "auto"); /** @@ -678,6 +682,8 @@ inline BatchSet DMatrix::GetBatches() { } } // namespace xgboost +DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode); + namespace dmlc { DMLC_DECLARE_TRAITS(is_pod, xgboost::Entry, true); diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 87a30283f..c0fb55b1c 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -207,16 +207,16 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) { XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) { API_BEGIN(); - bool load_row_split = false; + auto data_split_mode = DataSplitMode::kNone; if (collective::IsFederated()) { LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers"; } else if (collective::IsDistributed()) { LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers"; - load_row_split = true; + data_split_mode = DataSplitMode::kRow; } xgboost_CHECK_C_ARG_PTR(fname); xgboost_CHECK_C_ARG_PTR(out); - *out = new std::shared_ptr(DMatrix::Load(fname, silent != 0, load_row_split)); + *out = new std::shared_ptr(DMatrix::Load(fname, silent != 0, data_split_mode)); API_END(); } diff --git a/src/cli_main.cc b/src/cli_main.cc index de9ae6253..de279f04f 100644 --- a/src/cli_main.cc +++ b/src/cli_main.cc @@ -115,6 +115,7 @@ struct CLIParam : public XGBoostParameter { .add_enum("auto", 0) .add_enum("col", 1) .add_enum("row", 2) + .add_enum("none", 3) .describe("Data split mode."); DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0) .describe("(Deprecated) Use iteration_begin/iteration_end instead."); @@ -157,8 +158,14 @@ struct CLIParam : public XGBoostParameter { if (name_pred == "stdout") { save_period = 0; } - if (dsplit == 0 && collective::IsDistributed()) { - dsplit = 2; + if (dsplit == static_cast(DataSplitMode::kAuto)) { + if (collective::IsFederated()) { + dsplit = static_cast(DataSplitMode::kNone); + } else if (collective::IsDistributed()) { + dsplit = static_cast(DataSplitMode::kRow); + } else { + dsplit = static_cast(DataSplitMode::kNone); + } } } }; @@ -206,18 +213,17 @@ class CLI { } // load in data. std::shared_ptr dtrain(DMatrix::Load( - param_.train_path, - ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(), - param_.dsplit == 2)); + param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(), + static_cast(param_.dsplit))); std::vector> deval; std::vector> cache_mats; std::vector> eval_datasets; cache_mats.push_back(dtrain); for (size_t i = 0; i < param_.eval_data_names.size(); ++i) { - deval.emplace_back(std::shared_ptr(DMatrix::Load( - param_.eval_data_paths[i], - ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(), - param_.dsplit == 2))); + deval.emplace_back(std::shared_ptr( + DMatrix::Load(param_.eval_data_paths[i], + ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(), + static_cast(param_.dsplit)))); eval_datasets.push_back(deval.back()); cache_mats.push_back(deval.back()); } @@ -324,7 +330,7 @@ class CLI { std::shared_ptr dtest(DMatrix::Load( param_.test_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(), - param_.dsplit == 2)); + static_cast(param_.dsplit))); // load model CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for predict"; this->ResetLearner({}); diff --git a/src/data/data.cc b/src/data/data.cc index 3559ea00f..a4293d7b5 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -777,8 +777,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) { return nullptr; } -DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, +DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode, const std::string& file_format) { + CHECK(data_split_mode == DataSplitMode::kRow || data_split_mode == DataSplitMode::kNone) + << "Precondition violated; data split mode can only be 'row' or 'none'"; std::string fname, cache_file; size_t dlm_pos = uri.find('#'); if (dlm_pos != std::string::npos) { @@ -786,7 +788,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, fname = uri.substr(0, dlm_pos); CHECK_EQ(cache_file.find('#'), std::string::npos) << "Only one `#` is allowed in file path for cache file specification."; - if (load_row_split) { + if (data_split_mode == DataSplitMode::kRow) { std::ostringstream os; std::vector cache_shards = common::Split(cache_file, ':'); for (size_t i = 0; i < cache_shards.size(); ++i) { @@ -820,7 +822,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, } int partid = 0, npart = 1; - if (load_row_split) { + if (data_split_mode == DataSplitMode::kRow) { partid = collective::GetRank(); npart = collective::GetWorldSize(); } else { diff --git a/src/learner.cc b/src/learner.cc index d2386b006..7a8d9a4e1 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -53,15 +53,6 @@ namespace { const char* kMaxDeltaStepDefaultValue = "0.7"; } // anonymous namespace -namespace xgboost { - -enum class DataSplitMode : int { - kAuto = 0, kCol = 1, kRow = 2 -}; -} // namespace xgboost - -DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode); - namespace xgboost { Learner::~Learner() = default; namespace { @@ -298,6 +289,7 @@ struct LearnerTrainParam : public XGBoostParameter { .add_enum("auto", DataSplitMode::kAuto) .add_enum("col", DataSplitMode::kCol) .add_enum("row", DataSplitMode::kRow) + .add_enum("none", DataSplitMode::kNone) .describe("Data split mode for distributed training."); DMLC_DECLARE_FIELD(disable_default_eval_metric) .set_default(false) diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index c8881c158..463506215 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -89,7 +89,7 @@ inline std::shared_ptr GetExternalMemoryDMatrixFromData( } fo.close(); return std::shared_ptr(DMatrix::Load( - tmp_file + "#" + tmp_file + ".cache", true, false, "auto")); + tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone, "auto")); } // Test that elements are approximately equally distributed among bins diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc index 871a7f498..51390f62c 100644 --- a/tests/cpp/data/test_data.cc +++ b/tests/cpp/data/test_data.cc @@ -143,7 +143,7 @@ TEST(DMatrix, Uri) { // EXPECT_THROW(dmat.reset(DMatrix::Load(path, false, true)), dmlc::Error); std::string uri = path + "?format=csv"; - dmat.reset(DMatrix::Load(uri, false, true)); + dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kRow)); ASSERT_EQ(dmat->Info().num_col_, kCols); ASSERT_EQ(dmat->Info().num_row_, kRows); diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 342af77bf..c09b95c7e 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -175,7 +175,7 @@ TEST(MetaInfo, LoadQid) { os.set_stream(nullptr); } std::unique_ptr dmat( - xgboost::DMatrix::Load(tmp_file, true, false, "libsvm")); + xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone, "libsvm")); const xgboost::MetaInfo& info = dmat->Info(); const std::vector expected_group_ptr{0, 4, 8, 12}; diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 266115731..ad545ce14 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -15,7 +15,7 @@ TEST(SimpleDMatrix, MetaInfo) { dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm"; CreateSimpleTestData(tmp_file); - xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, false); + xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone); // Test the metadata that was parsed EXPECT_EQ(dmat->Info().num_row_, 2); @@ -30,7 +30,7 @@ TEST(SimpleDMatrix, RowAccess) { dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm"; CreateSimpleTestData(tmp_file); - xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false, false); + xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false, xgboost::DataSplitMode::kNone); // Loop over the batches and count the records int64_t row_count = 0; @@ -53,7 +53,7 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) { dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm"; CreateSimpleTestData(tmp_file); - xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, false); + xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone); ASSERT_TRUE(dmat->SingleColBlock()); @@ -304,12 +304,12 @@ TEST(SimpleDMatrix, SaveLoadBinary) { dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm"; CreateSimpleTestData(tmp_file); - xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false); + xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone); data::SimpleDMatrix *simple_dmat = dynamic_cast(dmat); const std::string tmp_binfile = tempdir.path + "/csr_source.binary"; simple_dmat->SaveToLocalFile(tmp_binfile); - xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile, true, false); + xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile, true, xgboost::DataSplitMode::kNone); EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_); EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_); diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index 68171932b..4a9c7562b 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -109,7 +109,7 @@ TEST(SparsePageDMatrix, MetaInfo) { CreateBigTestData(tmp_file, kEntries); xgboost::DMatrix *dmat = xgboost::DMatrix::Load( - tmp_file + "#" + tmp_file + ".cache", false, false); + tmp_file + "#" + tmp_file + ".cache", false, xgboost::DataSplitMode::kNone); // Test the metadata that was parsed EXPECT_EQ(dmat->Info().num_row_, 8ul); @@ -137,7 +137,7 @@ TEST(SparsePageDMatrix, ColAccess) { const std::string tmp_file = tempdir.path + "/simple.libsvm"; CreateSimpleTestData(tmp_file); xgboost::DMatrix *dmat = - xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false); + xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, xgboost::DataSplitMode::kNone); // Loop over the batches and assert the data is as expected size_t iter = 0; diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu index 07c86c93f..2dfa5fee1 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cu +++ b/tests/cpp/data/test_sparse_page_dmatrix.cu @@ -12,7 +12,7 @@ TEST(SparsePageDMatrix, EllpackPage) { dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm"; CreateSimpleTestData(tmp_file); - DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false); + DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone); // Loop over the batches and assert the data is as expected size_t n = 0; diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 0273d964f..bc7fe6bf5 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -528,7 +528,7 @@ std::unique_ptr CreateSparsePageDMatrixWithRC( uri += "#" + tmp_file + ".cache"; } std::unique_ptr dmat( - DMatrix::Load(uri, true, false, "auto")); + DMatrix::Load(uri, true, DataSplitMode::kNone, "auto")); return dmat; } diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 35dde0c9e..5090fb57c 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -99,7 +99,7 @@ TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT const std::string tmp_file = tempdir.path + "/big.libsvm"; CreateBigTestData(tmp_file, 50000); std::shared_ptr dmat(xgboost::DMatrix::Load( - tmp_file + "#" + tmp_file + ".cache", true, false, "auto")); + tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone, "auto")); EXPECT_FALSE(dmat->SingleColBlock()); size_t num_row = dmat->Info().num_row_; std::vector labels(num_row);