Use DataSplitMode to configure data loading (#8434)

* Use `DataSplitMode` to configure data loading
This commit is contained in:
Rong Ou 2022-11-08 00:21:50 -08:00 committed by GitHub
parent 0d3da9869c
commit 8e76f5f595
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 46 additions and 40 deletions

View File

@ -40,6 +40,10 @@ enum class DataType : uint8_t {
enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 }; enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };
/*!
 * \brief How the input data is divided among workers when it is loaded.
 *
 * The integer values are relied upon by code that stores the mode as a plain
 * int (the CLI `dsplit` parameter registers "auto" = 0, "col" = 1, "row" = 2,
 * "none" = 3 and later casts the int back to this enum), so do not renumber.
 *
 *  - kAuto: no explicit choice was made. Callers are expected to resolve it
 *           before loading (the CLI picks kRow when running distributed and
 *           not federated, otherwise kNone); DMatrix::Load rejects it.
 *  - kCol:  split by column. DMatrix::Load currently rejects this mode.
 *  - kRow:  each worker reads its own part of the rows; DMatrix::Load uses
 *           the collective rank / world size as the part id / part count.
 *  - kNone: the data is not split among workers.
 */
enum class DataSplitMode : int {
kAuto = 0, kCol = 1, kRow = 2, kNone = 3
};
/*! /*!
* \brief Meta information about dataset, always sit in memory. * \brief Meta information about dataset, always sit in memory.
*/ */
@ -537,7 +541,7 @@ class DMatrix {
* \brief Load DMatrix from URI. * \brief Load DMatrix from URI.
* \param uri The URI of input. * \param uri The URI of input.
* \param silent Whether print information during loading. * \param silent Whether print information during loading.
* \param load_row_split Flag to read in part of rows, divided among the workers in distributed mode. * \param data_split_mode Mode to read in part of the data, divided among the workers in distributed mode.
* \param file_format The format type of the file, used for dmlc::Parser::Create. * \param file_format The format type of the file, used for dmlc::Parser::Create.
* By default "auto" will be able to load in both local binary file. * By default "auto" will be able to load in both local binary file.
* \param page_size Page size for external memory. * \param page_size Page size for external memory.
@ -545,7 +549,7 @@ class DMatrix {
*/ */
static DMatrix* Load(const std::string& uri, static DMatrix* Load(const std::string& uri,
bool silent, bool silent,
bool load_row_split, DataSplitMode data_split_mode,
const std::string& file_format = "auto"); const std::string& file_format = "auto");
/** /**
@ -678,6 +682,8 @@ inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
} }
} // namespace xgboost } // namespace xgboost
DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode);
namespace dmlc { namespace dmlc {
DMLC_DECLARE_TRAITS(is_pod, xgboost::Entry, true); DMLC_DECLARE_TRAITS(is_pod, xgboost::Entry, true);

View File

@ -207,16 +207,16 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) { XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
API_BEGIN(); API_BEGIN();
bool load_row_split = false; auto data_split_mode = DataSplitMode::kNone;
if (collective::IsFederated()) { if (collective::IsFederated()) {
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers"; LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
} else if (collective::IsDistributed()) { } else if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers"; LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
load_row_split = true; data_split_mode = DataSplitMode::kRow;
} }
xgboost_CHECK_C_ARG_PTR(fname); xgboost_CHECK_C_ARG_PTR(fname);
xgboost_CHECK_C_ARG_PTR(out); xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split)); *out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, data_split_mode));
API_END(); API_END();
} }

View File

@ -115,6 +115,7 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
.add_enum("auto", 0) .add_enum("auto", 0)
.add_enum("col", 1) .add_enum("col", 1)
.add_enum("row", 2) .add_enum("row", 2)
.add_enum("none", 3)
.describe("Data split mode."); .describe("Data split mode.");
DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0) DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
.describe("(Deprecated) Use iteration_begin/iteration_end instead."); .describe("(Deprecated) Use iteration_begin/iteration_end instead.");
@ -157,8 +158,14 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
if (name_pred == "stdout") { if (name_pred == "stdout") {
save_period = 0; save_period = 0;
} }
if (dsplit == 0 && collective::IsDistributed()) { if (dsplit == static_cast<int>(DataSplitMode::kAuto)) {
dsplit = 2; if (collective::IsFederated()) {
dsplit = static_cast<int>(DataSplitMode::kNone);
} else if (collective::IsDistributed()) {
dsplit = static_cast<int>(DataSplitMode::kRow);
} else {
dsplit = static_cast<int>(DataSplitMode::kNone);
}
} }
} }
}; };
@ -206,18 +213,17 @@ class CLI {
} }
// load in data. // load in data.
std::shared_ptr<DMatrix> dtrain(DMatrix::Load( std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
param_.train_path, param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(), static_cast<DataSplitMode>(param_.dsplit)));
param_.dsplit == 2));
std::vector<std::shared_ptr<DMatrix>> deval; std::vector<std::shared_ptr<DMatrix>> deval;
std::vector<std::shared_ptr<DMatrix>> cache_mats; std::vector<std::shared_ptr<DMatrix>> cache_mats;
std::vector<std::shared_ptr<DMatrix>> eval_datasets; std::vector<std::shared_ptr<DMatrix>> eval_datasets;
cache_mats.push_back(dtrain); cache_mats.push_back(dtrain);
for (size_t i = 0; i < param_.eval_data_names.size(); ++i) { for (size_t i = 0; i < param_.eval_data_names.size(); ++i) {
deval.emplace_back(std::shared_ptr<DMatrix>(DMatrix::Load( deval.emplace_back(std::shared_ptr<DMatrix>(
param_.eval_data_paths[i], DMatrix::Load(param_.eval_data_paths[i],
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(), ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
param_.dsplit == 2))); static_cast<DataSplitMode>(param_.dsplit))));
eval_datasets.push_back(deval.back()); eval_datasets.push_back(deval.back());
cache_mats.push_back(deval.back()); cache_mats.push_back(deval.back());
} }
@ -324,7 +330,7 @@ class CLI {
std::shared_ptr<DMatrix> dtest(DMatrix::Load( std::shared_ptr<DMatrix> dtest(DMatrix::Load(
param_.test_path, param_.test_path,
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(), ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
param_.dsplit == 2)); static_cast<DataSplitMode>(param_.dsplit)));
// load model // load model
CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for predict"; CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for predict";
this->ResetLearner({}); this->ResetLearner({});

View File

@ -777,8 +777,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
return nullptr; return nullptr;
} }
DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
const std::string& file_format) { const std::string& file_format) {
CHECK(data_split_mode == DataSplitMode::kRow || data_split_mode == DataSplitMode::kNone)
<< "Precondition violated; data split mode can only be 'row' or 'none'";
std::string fname, cache_file; std::string fname, cache_file;
size_t dlm_pos = uri.find('#'); size_t dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) { if (dlm_pos != std::string::npos) {
@ -786,7 +788,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
fname = uri.substr(0, dlm_pos); fname = uri.substr(0, dlm_pos);
CHECK_EQ(cache_file.find('#'), std::string::npos) CHECK_EQ(cache_file.find('#'), std::string::npos)
<< "Only one `#` is allowed in file path for cache file specification."; << "Only one `#` is allowed in file path for cache file specification.";
if (load_row_split) { if (data_split_mode == DataSplitMode::kRow) {
std::ostringstream os; std::ostringstream os;
std::vector<std::string> cache_shards = common::Split(cache_file, ':'); std::vector<std::string> cache_shards = common::Split(cache_file, ':');
for (size_t i = 0; i < cache_shards.size(); ++i) { for (size_t i = 0; i < cache_shards.size(); ++i) {
@ -820,7 +822,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
} }
int partid = 0, npart = 1; int partid = 0, npart = 1;
if (load_row_split) { if (data_split_mode == DataSplitMode::kRow) {
partid = collective::GetRank(); partid = collective::GetRank();
npart = collective::GetWorldSize(); npart = collective::GetWorldSize();
} else { } else {

View File

@ -53,15 +53,6 @@ namespace {
const char* kMaxDeltaStepDefaultValue = "0.7"; const char* kMaxDeltaStepDefaultValue = "0.7";
} // anonymous namespace } // anonymous namespace
namespace xgboost {
enum class DataSplitMode : int {
kAuto = 0, kCol = 1, kRow = 2
};
} // namespace xgboost
DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode);
namespace xgboost { namespace xgboost {
Learner::~Learner() = default; Learner::~Learner() = default;
namespace { namespace {
@ -298,6 +289,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
.add_enum("auto", DataSplitMode::kAuto) .add_enum("auto", DataSplitMode::kAuto)
.add_enum("col", DataSplitMode::kCol) .add_enum("col", DataSplitMode::kCol)
.add_enum("row", DataSplitMode::kRow) .add_enum("row", DataSplitMode::kRow)
.add_enum("none", DataSplitMode::kNone)
.describe("Data split mode for distributed training."); .describe("Data split mode for distributed training.");
DMLC_DECLARE_FIELD(disable_default_eval_metric) DMLC_DECLARE_FIELD(disable_default_eval_metric)
.set_default(false) .set_default(false)

View File

@ -89,7 +89,7 @@ inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(
} }
fo.close(); fo.close();
return std::shared_ptr<DMatrix>(DMatrix::Load( return std::shared_ptr<DMatrix>(DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", true, false, "auto")); tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone, "auto"));
} }
// Test that elements are approximately equally distributed among bins // Test that elements are approximately equally distributed among bins

View File

@ -143,7 +143,7 @@ TEST(DMatrix, Uri) {
// EXPECT_THROW(dmat.reset(DMatrix::Load(path, false, true)), dmlc::Error); // EXPECT_THROW(dmat.reset(DMatrix::Load(path, false, true)), dmlc::Error);
std::string uri = path + "?format=csv"; std::string uri = path + "?format=csv";
dmat.reset(DMatrix::Load(uri, false, true)); dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kRow));
ASSERT_EQ(dmat->Info().num_col_, kCols); ASSERT_EQ(dmat->Info().num_col_, kCols);
ASSERT_EQ(dmat->Info().num_row_, kRows); ASSERT_EQ(dmat->Info().num_row_, kRows);

View File

@ -175,7 +175,7 @@ TEST(MetaInfo, LoadQid) {
os.set_stream(nullptr); os.set_stream(nullptr);
} }
std::unique_ptr<xgboost::DMatrix> dmat( std::unique_ptr<xgboost::DMatrix> dmat(
xgboost::DMatrix::Load(tmp_file, true, false, "libsvm")); xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone, "libsvm"));
const xgboost::MetaInfo& info = dmat->Info(); const xgboost::MetaInfo& info = dmat->Info();
const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12}; const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};

View File

@ -15,7 +15,7 @@ TEST(SimpleDMatrix, MetaInfo) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, false); xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone);
// Test the metadata that was parsed // Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 2); EXPECT_EQ(dmat->Info().num_row_, 2);
@ -30,7 +30,7 @@ TEST(SimpleDMatrix, RowAccess) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false, false); xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false, xgboost::DataSplitMode::kNone);
// Loop over the batches and count the records // Loop over the batches and count the records
int64_t row_count = 0; int64_t row_count = 0;
@ -53,7 +53,7 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, false); xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone);
ASSERT_TRUE(dmat->SingleColBlock()); ASSERT_TRUE(dmat->SingleColBlock());
@ -304,12 +304,12 @@ TEST(SimpleDMatrix, SaveLoadBinary) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false); xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone);
data::SimpleDMatrix *simple_dmat = dynamic_cast<data::SimpleDMatrix*>(dmat); data::SimpleDMatrix *simple_dmat = dynamic_cast<data::SimpleDMatrix*>(dmat);
const std::string tmp_binfile = tempdir.path + "/csr_source.binary"; const std::string tmp_binfile = tempdir.path + "/csr_source.binary";
simple_dmat->SaveToLocalFile(tmp_binfile); simple_dmat->SaveToLocalFile(tmp_binfile);
xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile, true, false); xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile, true, xgboost::DataSplitMode::kNone);
EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_); EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_);
EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_); EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_);

View File

@ -109,7 +109,7 @@ TEST(SparsePageDMatrix, MetaInfo) {
CreateBigTestData(tmp_file, kEntries); CreateBigTestData(tmp_file, kEntries);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load( xgboost::DMatrix *dmat = xgboost::DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", false, false); tmp_file + "#" + tmp_file + ".cache", false, xgboost::DataSplitMode::kNone);
// Test the metadata that was parsed // Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 8ul); EXPECT_EQ(dmat->Info().num_row_, 8ul);
@ -137,7 +137,7 @@ TEST(SparsePageDMatrix, ColAccess) {
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
xgboost::DMatrix *dmat = xgboost::DMatrix *dmat =
xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false); xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, xgboost::DataSplitMode::kNone);
// Loop over the batches and assert the data is as expected // Loop over the batches and assert the data is as expected
size_t iter = 0; size_t iter = 0;

View File

@ -12,7 +12,7 @@ TEST(SparsePageDMatrix, EllpackPage) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); CreateSimpleTestData(tmp_file);
DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false); DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone);
// Loop over the batches and assert the data is as expected // Loop over the batches and assert the data is as expected
size_t n = 0; size_t n = 0;

View File

@ -528,7 +528,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
uri += "#" + tmp_file + ".cache"; uri += "#" + tmp_file + ".cache";
} }
std::unique_ptr<DMatrix> dmat( std::unique_ptr<DMatrix> dmat(
DMatrix::Load(uri, true, false, "auto")); DMatrix::Load(uri, true, DataSplitMode::kNone, "auto"));
return dmat; return dmat;
} }

View File

@ -99,7 +99,7 @@ TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT
const std::string tmp_file = tempdir.path + "/big.libsvm"; const std::string tmp_file = tempdir.path + "/big.libsvm";
CreateBigTestData(tmp_file, 50000); CreateBigTestData(tmp_file, 50000);
std::shared_ptr<DMatrix> dmat(xgboost::DMatrix::Load( std::shared_ptr<DMatrix> dmat(xgboost::DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", true, false, "auto")); tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone, "auto"));
EXPECT_FALSE(dmat->SingleColBlock()); EXPECT_FALSE(dmat->SingleColBlock());
size_t num_row = dmat->Info().num_row_; size_t num_row = dmat->Info().num_row_;
std::vector<bst_float> labels(num_row); std::vector<bst_float> labels(num_row);