Use DataSplitMode to configure data loading (#8434)

* Use `DataSplitMode` to configure data loading
This commit is contained in:
Rong Ou
2022-11-08 00:21:50 -08:00
committed by GitHub
parent 0d3da9869c
commit 8e76f5f595
13 changed files with 46 additions and 40 deletions

View File

@@ -207,16 +207,16 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
API_BEGIN();
bool load_row_split = false;
auto data_split_mode = DataSplitMode::kNone;
if (collective::IsFederated()) {
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
} else if (collective::IsDistributed()) {
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
load_row_split = true;
data_split_mode = DataSplitMode::kRow;
}
xgboost_CHECK_C_ARG_PTR(fname);
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, data_split_mode));
API_END();
}

View File

@@ -115,6 +115,7 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
.add_enum("auto", 0)
.add_enum("col", 1)
.add_enum("row", 2)
.add_enum("none", 3)
.describe("Data split mode.");
DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
.describe("(Deprecated) Use iteration_begin/iteration_end instead.");
@@ -157,8 +158,14 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
if (name_pred == "stdout") {
save_period = 0;
}
if (dsplit == 0 && collective::IsDistributed()) {
dsplit = 2;
if (dsplit == static_cast<int>(DataSplitMode::kAuto)) {
if (collective::IsFederated()) {
dsplit = static_cast<int>(DataSplitMode::kNone);
} else if (collective::IsDistributed()) {
dsplit = static_cast<int>(DataSplitMode::kRow);
} else {
dsplit = static_cast<int>(DataSplitMode::kNone);
}
}
}
};
@@ -206,18 +213,17 @@ class CLI {
}
// load in data.
std::shared_ptr<DMatrix> dtrain(DMatrix::Load(
param_.train_path,
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
param_.dsplit == 2));
param_.train_path, ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
static_cast<DataSplitMode>(param_.dsplit)));
std::vector<std::shared_ptr<DMatrix>> deval;
std::vector<std::shared_ptr<DMatrix>> cache_mats;
std::vector<std::shared_ptr<DMatrix>> eval_datasets;
cache_mats.push_back(dtrain);
for (size_t i = 0; i < param_.eval_data_names.size(); ++i) {
deval.emplace_back(std::shared_ptr<DMatrix>(DMatrix::Load(
param_.eval_data_paths[i],
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
param_.dsplit == 2)));
deval.emplace_back(std::shared_ptr<DMatrix>(
DMatrix::Load(param_.eval_data_paths[i],
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
static_cast<DataSplitMode>(param_.dsplit))));
eval_datasets.push_back(deval.back());
cache_mats.push_back(deval.back());
}
@@ -324,7 +330,7 @@ class CLI {
std::shared_ptr<DMatrix> dtest(DMatrix::Load(
param_.test_path,
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
param_.dsplit == 2));
static_cast<DataSplitMode>(param_.dsplit)));
// load model
CHECK_NE(param_.model_in, CLIParam::kNull) << "Must specify model_in for predict";
this->ResetLearner({});

View File

@@ -777,8 +777,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
return nullptr;
}
DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
const std::string& file_format) {
CHECK(data_split_mode == DataSplitMode::kRow || data_split_mode == DataSplitMode::kNone)
<< "Precondition violated; data split mode can only be 'row' or 'none'";
std::string fname, cache_file;
size_t dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) {
@@ -786,7 +788,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
fname = uri.substr(0, dlm_pos);
CHECK_EQ(cache_file.find('#'), std::string::npos)
<< "Only one `#` is allowed in file path for cache file specification.";
if (load_row_split) {
if (data_split_mode == DataSplitMode::kRow) {
std::ostringstream os;
std::vector<std::string> cache_shards = common::Split(cache_file, ':');
for (size_t i = 0; i < cache_shards.size(); ++i) {
@@ -820,7 +822,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
}
int partid = 0, npart = 1;
if (load_row_split) {
if (data_split_mode == DataSplitMode::kRow) {
partid = collective::GetRank();
npart = collective::GetWorldSize();
} else {

View File

@@ -53,15 +53,6 @@ namespace {
const char* kMaxDeltaStepDefaultValue = "0.7";
} // anonymous namespace
namespace xgboost {
enum class DataSplitMode : int {
kAuto = 0, kCol = 1, kRow = 2
};
} // namespace xgboost
DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode);
namespace xgboost {
Learner::~Learner() = default;
namespace {
@@ -298,6 +289,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
.add_enum("auto", DataSplitMode::kAuto)
.add_enum("col", DataSplitMode::kCol)
.add_enum("row", DataSplitMode::kRow)
.add_enum("none", DataSplitMode::kNone)
.describe("Data split mode for distributed training.");
DMLC_DECLARE_FIELD(disable_default_eval_metric)
.set_default(false)