Use DataSplitMode to configure data loading (#8434)
* Use `DataSplitMode` to configure data loading
This commit is contained in:
parent
0d3da9869c
commit
8e76f5f595
@ -40,6 +40,10 @@ enum class DataType : uint8_t {
|
||||
|
||||
enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };
|
||||
|
||||
enum class DataSplitMode : int {
|
||||
kAuto = 0, kCol = 1, kRow = 2, kNone = 3
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Meta information about dataset, always sit in memory.
|
||||
*/
|
||||
@ -537,7 +541,7 @@ class DMatrix {
|
||||
* \brief Load DMatrix from URI.
|
||||
* \param uri The URI of input.
|
||||
* \param silent Whether print information during loading.
|
||||
* \param load_row_split Flag to read in part of rows, divided among the workers in distributed mode.
|
||||
* \param data_split_mode Mode to read in part of the data, divided among the workers in distributed mode.
|
||||
* \param file_format The format type of the file, used for dmlc::Parser::Create.
|
||||
* By default "auto" will be able to load in both local binary file.
|
||||
* \param page_size Page size for external memory.
|
||||
@ -545,7 +549,7 @@ class DMatrix {
|
||||
*/
|
||||
static DMatrix* Load(const std::string& uri,
|
||||
bool silent,
|
||||
bool load_row_split,
|
||||
DataSplitMode data_split_mode,
|
||||
const std::string& file_format = "auto");
|
||||
|
||||
/**
|
||||
@ -678,6 +682,8 @@ inline BatchSet<ExtSparsePage> DMatrix::GetBatches() {
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode);
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_DECLARE_TRAITS(is_pod, xgboost::Entry, true);
|
||||
|
||||
|
||||
@ -207,16 +207,16 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
bool load_row_split = false;
|
||||
auto data_split_mode = DataSplitMode::kNone;
|
||||
if (collective::IsFederated()) {
|
||||
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
|
||||
} else if (collective::IsDistributed()) {
|
||||
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
|
||||
load_row_split = true;
|
||||
data_split_mode = DataSplitMode::kRow;
|
||||
}
|
||||
xgboost_CHECK_C_ARG_PTR(fname);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, data_split_mode));
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@ -115,6 +115,7 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
|
||||
.add_enum("auto", 0)
|
||||
.add_enum("col", 1)
|
||||
.add_enum("row", 2)
|
||||
.add_enum("none", 3)
|
||||
.describe("Data split mode.");
|
||||
DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
|
||||
.describe("(Deprecated) Use iteration_begin/iteration_end instead.");
|
||||
@ -157,8 +158,14 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
|
||||
if (name_pred == "stdout") {
|
||||
save_period = 0;
|
||||
}
|
||||
if (dsplit == 0 && collective::IsDistributed()) {
|
||||
dsplit = 2;
|
||||
if (dsplit == static_cast<int>(DataSplitMode::kAuto)) {
|
||||
if (collective::IsFederated()) {
|
||||
dsplit = static_cast<int>(DataSplitMode::kNone);
|
||||
} else if (collective::IsDistributed()) {
|
||||
dsplit = static_cast<int>(DataSplitMode::kRow);
|
||||
} else {
|
||||
dsplit = static_cast<int>(DataSplitMode::kNone);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -206,18 +213,17 @@ class CLI {
|
||||
}
|
||||
// load in data.
|
||||
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
|
||||
param_.train_path,
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
param_.dsplit == 2));
|
||||
param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
static_cast<DataSplitMode>(param_.dsplit)));
|
||||
std::vector<std::shared_ptr<DMatrix>> deval;
|
||||
std::vector<std::shared_ptr<DMatrix>> cache_mats;
|
||||
std::vector<std::shared_ptr<DMatrix>> eval_datasets;
|
||||
cache_mats.push_back(dtrain);
|
||||
for (size_t i = 0; i < param_.eval_data_names.size(); ++i) {
|
||||
deval.emplace_back(std::shared_ptr<DMatrix>(DMatrix::Load(
|
||||
param_.eval_data_paths[i],
|
||||
deval.emplace_back(std::shared_ptr<DMatrix>(
|
||||
DMatrix::Load(param_.eval_data_paths[i],
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
param_.dsplit == 2)));
|
||||
static_cast<DataSplitMode>(param_.dsplit))));
|
||||
eval_datasets.push_back(deval.back());
|
||||
cache_mats.push_back(deval.back());
|
||||
}
|
||||
@ -324,7 +330,7 @@ class CLI {
|
||||
std::shared_ptr<DMatrix> dtest(DMatrix::Load(
|
||||
param_.test_path,
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
param_.dsplit == 2));
|
||||
static_cast<DataSplitMode>(param_.dsplit)));
|
||||
// load model
|
||||
CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for predict";
|
||||
this->ResetLearner({});
|
||||
|
||||
@ -777,8 +777,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
|
||||
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
|
||||
const std::string& file_format) {
|
||||
CHECK(data_split_mode == DataSplitMode::kRow || data_split_mode == DataSplitMode::kNone)
|
||||
<< "Precondition violated; data split mode can only be 'row' or 'none'";
|
||||
std::string fname, cache_file;
|
||||
size_t dlm_pos = uri.find('#');
|
||||
if (dlm_pos != std::string::npos) {
|
||||
@ -786,7 +788,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
|
||||
fname = uri.substr(0, dlm_pos);
|
||||
CHECK_EQ(cache_file.find('#'), std::string::npos)
|
||||
<< "Only one `#` is allowed in file path for cache file specification.";
|
||||
if (load_row_split) {
|
||||
if (data_split_mode == DataSplitMode::kRow) {
|
||||
std::ostringstream os;
|
||||
std::vector<std::string> cache_shards = common::Split(cache_file, ':');
|
||||
for (size_t i = 0; i < cache_shards.size(); ++i) {
|
||||
@ -820,7 +822,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
|
||||
}
|
||||
|
||||
int partid = 0, npart = 1;
|
||||
if (load_row_split) {
|
||||
if (data_split_mode == DataSplitMode::kRow) {
|
||||
partid = collective::GetRank();
|
||||
npart = collective::GetWorldSize();
|
||||
} else {
|
||||
|
||||
@ -53,15 +53,6 @@ namespace {
|
||||
const char* kMaxDeltaStepDefaultValue = "0.7";
|
||||
} // anonymous namespace
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
enum class DataSplitMode : int {
|
||||
kAuto = 0, kCol = 1, kRow = 2
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode);
|
||||
|
||||
namespace xgboost {
|
||||
Learner::~Learner() = default;
|
||||
namespace {
|
||||
@ -298,6 +289,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
||||
.add_enum("auto", DataSplitMode::kAuto)
|
||||
.add_enum("col", DataSplitMode::kCol)
|
||||
.add_enum("row", DataSplitMode::kRow)
|
||||
.add_enum("none", DataSplitMode::kNone)
|
||||
.describe("Data split mode for distributed training.");
|
||||
DMLC_DECLARE_FIELD(disable_default_eval_metric)
|
||||
.set_default(false)
|
||||
|
||||
@ -89,7 +89,7 @@ inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(
|
||||
}
|
||||
fo.close();
|
||||
return std::shared_ptr<DMatrix>(DMatrix::Load(
|
||||
tmp_file + "#" + tmp_file + ".cache", true, false, "auto"));
|
||||
tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone, "auto"));
|
||||
}
|
||||
|
||||
// Test that elements are approximately equally distributed among bins
|
||||
|
||||
@ -143,7 +143,7 @@ TEST(DMatrix, Uri) {
|
||||
// EXPECT_THROW(dmat.reset(DMatrix::Load(path, false, true)), dmlc::Error);
|
||||
|
||||
std::string uri = path + "?format=csv";
|
||||
dmat.reset(DMatrix::Load(uri, false, true));
|
||||
dmat.reset(DMatrix::Load(uri, false, DataSplitMode::kRow));
|
||||
|
||||
ASSERT_EQ(dmat->Info().num_col_, kCols);
|
||||
ASSERT_EQ(dmat->Info().num_row_, kRows);
|
||||
|
||||
@ -175,7 +175,7 @@ TEST(MetaInfo, LoadQid) {
|
||||
os.set_stream(nullptr);
|
||||
}
|
||||
std::unique_ptr<xgboost::DMatrix> dmat(
|
||||
xgboost::DMatrix::Load(tmp_file, true, false, "libsvm"));
|
||||
xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone, "libsvm"));
|
||||
|
||||
const xgboost::MetaInfo& info = dmat->Info();
|
||||
const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};
|
||||
|
||||
@ -15,7 +15,7 @@ TEST(SimpleDMatrix, MetaInfo) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, false);
|
||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone);
|
||||
|
||||
// Test the metadata that was parsed
|
||||
EXPECT_EQ(dmat->Info().num_row_, 2);
|
||||
@ -30,7 +30,7 @@ TEST(SimpleDMatrix, RowAccess) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false, false);
|
||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false, xgboost::DataSplitMode::kNone);
|
||||
|
||||
// Loop over the batches and count the records
|
||||
int64_t row_count = 0;
|
||||
@ -53,7 +53,7 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, false);
|
||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone);
|
||||
|
||||
ASSERT_TRUE(dmat->SingleColBlock());
|
||||
|
||||
@ -304,12 +304,12 @@ TEST(SimpleDMatrix, SaveLoadBinary) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
|
||||
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, xgboost::DataSplitMode::kNone);
|
||||
data::SimpleDMatrix *simple_dmat = dynamic_cast<data::SimpleDMatrix*>(dmat);
|
||||
|
||||
const std::string tmp_binfile = tempdir.path + "/csr_source.binary";
|
||||
simple_dmat->SaveToLocalFile(tmp_binfile);
|
||||
xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile, true, false);
|
||||
xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile, true, xgboost::DataSplitMode::kNone);
|
||||
|
||||
EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_);
|
||||
EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_);
|
||||
|
||||
@ -109,7 +109,7 @@ TEST(SparsePageDMatrix, MetaInfo) {
|
||||
CreateBigTestData(tmp_file, kEntries);
|
||||
|
||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(
|
||||
tmp_file + "#" + tmp_file + ".cache", false, false);
|
||||
tmp_file + "#" + tmp_file + ".cache", false, xgboost::DataSplitMode::kNone);
|
||||
|
||||
// Test the metadata that was parsed
|
||||
EXPECT_EQ(dmat->Info().num_row_, 8ul);
|
||||
@ -137,7 +137,7 @@ TEST(SparsePageDMatrix, ColAccess) {
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
xgboost::DMatrix *dmat =
|
||||
xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false);
|
||||
xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, xgboost::DataSplitMode::kNone);
|
||||
|
||||
// Loop over the batches and assert the data is as expected
|
||||
size_t iter = 0;
|
||||
|
||||
@ -12,7 +12,7 @@ TEST(SparsePageDMatrix, EllpackPage) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false);
|
||||
DMatrix* dmat = DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone);
|
||||
|
||||
// Loop over the batches and assert the data is as expected
|
||||
size_t n = 0;
|
||||
|
||||
@ -528,7 +528,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
|
||||
uri += "#" + tmp_file + ".cache";
|
||||
}
|
||||
std::unique_ptr<DMatrix> dmat(
|
||||
DMatrix::Load(uri, true, false, "auto"));
|
||||
DMatrix::Load(uri, true, DataSplitMode::kNone, "auto"));
|
||||
return dmat;
|
||||
}
|
||||
|
||||
|
||||
@ -99,7 +99,7 @@ TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT
|
||||
const std::string tmp_file = tempdir.path + "/big.libsvm";
|
||||
CreateBigTestData(tmp_file, 50000);
|
||||
std::shared_ptr<DMatrix> dmat(xgboost::DMatrix::Load(
|
||||
tmp_file + "#" + tmp_file + ".cache", true, false, "auto"));
|
||||
tmp_file + "#" + tmp_file + ".cache", true, DataSplitMode::kNone, "auto"));
|
||||
EXPECT_FALSE(dmat->SingleColBlock());
|
||||
size_t num_row = dmat->Info().num_row_;
|
||||
std::vector<bst_float> labels(num_row);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user