Add data split mode to DMatrix MetaInfo (#8568)
This commit is contained in:
@@ -206,17 +206,29 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
auto data_split_mode = DataSplitMode::kNone;
|
||||
if (collective::IsFederated()) {
|
||||
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
|
||||
} else if (collective::IsDistributed()) {
|
||||
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
|
||||
data_split_mode = DataSplitMode::kRow;
|
||||
}
|
||||
xgboost_CHECK_C_ARG_PTR(fname);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, data_split_mode));
|
||||
|
||||
Json config{Object()};
|
||||
config["uri"] = std::string{fname};
|
||||
config["silent"] = silent;
|
||||
std::string config_str;
|
||||
Json::Dump(config, &config_str);
|
||||
return XGDMatrixCreateFromURI(config_str.c_str(), out);
|
||||
}
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromURI(const char *config, DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
xgboost_CHECK_C_ARG_PTR(config);
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
|
||||
auto jconfig = Json::Load(StringView{config});
|
||||
std::string uri = RequiredArg<String>(jconfig, "uri", __func__);
|
||||
auto silent = static_cast<bool>(OptionalArg<Integer, int64_t>(jconfig, "silent", 1));
|
||||
auto data_split_mode =
|
||||
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(jconfig, "data_split_mode", 0));
|
||||
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(uri, silent, data_split_mode));
|
||||
API_END();
|
||||
}
|
||||
|
||||
|
||||
@@ -112,10 +112,8 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
|
||||
DMLC_DECLARE_FIELD(name_pred).set_default("pred.txt")
|
||||
.describe("Name of the prediction file.");
|
||||
DMLC_DECLARE_FIELD(dsplit).set_default(0)
|
||||
.add_enum("auto", 0)
|
||||
.add_enum("row", 0)
|
||||
.add_enum("col", 1)
|
||||
.add_enum("row", 2)
|
||||
.add_enum("none", 3)
|
||||
.describe("Data split mode.");
|
||||
DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
|
||||
.describe("(Deprecated) Use iteration_begin/iteration_end instead.");
|
||||
@@ -158,15 +156,6 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
|
||||
if (name_pred == "stdout") {
|
||||
save_period = 0;
|
||||
}
|
||||
if (dsplit == static_cast<int>(DataSplitMode::kAuto)) {
|
||||
if (collective::IsFederated()) {
|
||||
dsplit = static_cast<int>(DataSplitMode::kNone);
|
||||
} else if (collective::IsDistributed()) {
|
||||
dsplit = static_cast<int>(DataSplitMode::kRow);
|
||||
} else {
|
||||
dsplit = static_cast<int>(DataSplitMode::kNone);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -783,10 +783,14 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
|
||||
|
||||
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
|
||||
const std::string& file_format) {
|
||||
CHECK(data_split_mode == DataSplitMode::kRow ||
|
||||
data_split_mode == DataSplitMode::kCol ||
|
||||
data_split_mode == DataSplitMode::kNone)
|
||||
<< "Precondition violated; data split mode can only be 'row', 'col', or 'none'";
|
||||
auto need_split = false;
|
||||
if (collective::IsFederated()) {
|
||||
LOG(CONSOLE) << "XGBoost federated mode detected, not splitting data among workers";
|
||||
} else if (collective::IsDistributed()) {
|
||||
LOG(CONSOLE) << "XGBoost distributed mode detected, will split data among workers";
|
||||
need_split = true;
|
||||
}
|
||||
|
||||
std::string fname, cache_file;
|
||||
size_t dlm_pos = uri.find('#');
|
||||
if (dlm_pos != std::string::npos) {
|
||||
@@ -794,7 +798,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
fname = uri.substr(0, dlm_pos);
|
||||
CHECK_EQ(cache_file.find('#'), std::string::npos)
|
||||
<< "Only one `#` is allowed in file path for cache file specification.";
|
||||
if (data_split_mode == DataSplitMode::kRow) {
|
||||
if (need_split && data_split_mode == DataSplitMode::kRow) {
|
||||
std::ostringstream os;
|
||||
std::vector<std::string> cache_shards = common::Split(cache_file, ':');
|
||||
for (size_t i = 0; i < cache_shards.size(); ++i) {
|
||||
@@ -828,7 +832,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
}
|
||||
|
||||
int partid = 0, npart = 1;
|
||||
if (data_split_mode == DataSplitMode::kRow) {
|
||||
if (need_split && data_split_mode == DataSplitMode::kRow) {
|
||||
partid = collective::GetRank();
|
||||
npart = collective::GetWorldSize();
|
||||
} else {
|
||||
@@ -887,7 +891,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
* since partitioned data not knowing the real number of features. */
|
||||
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
|
||||
|
||||
if (data_split_mode == DataSplitMode::kCol) {
|
||||
if (need_split && data_split_mode == DataSplitMode::kCol) {
|
||||
if (!cache_file.empty()) {
|
||||
LOG(FATAL) << "Column-wise data split is not support for external memory.";
|
||||
}
|
||||
@@ -898,6 +902,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
|
||||
delete dmat;
|
||||
return sliced;
|
||||
} else {
|
||||
dmat->Info().data_split_mode = data_split_mode;
|
||||
return dmat;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +65,7 @@ DMatrix* SimpleDMatrix::SliceCol(std::size_t start, std::size_t size) {
|
||||
out->Info() = this->Info().Copy();
|
||||
out->Info().num_nonzero_ = h_offset.back();
|
||||
}
|
||||
out->Info().data_split_mode = DataSplitMode::kCol;
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@@ -273,8 +273,6 @@ void LearnerModelParam::Copy(LearnerModelParam const& that) {
|
||||
}
|
||||
|
||||
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
||||
// data split mode, can be row, col, or none.
|
||||
DataSplitMode dsplit {DataSplitMode::kAuto};
|
||||
// flag to disable default metric
|
||||
bool disable_default_eval_metric {false};
|
||||
// FIXME(trivialfis): The following parameters belong to model itself, but can be
|
||||
@@ -284,13 +282,6 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
||||
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(LearnerTrainParam) {
|
||||
DMLC_DECLARE_FIELD(dsplit)
|
||||
.set_default(DataSplitMode::kAuto)
|
||||
.add_enum("auto", DataSplitMode::kAuto)
|
||||
.add_enum("col", DataSplitMode::kCol)
|
||||
.add_enum("row", DataSplitMode::kRow)
|
||||
.add_enum("none", DataSplitMode::kNone)
|
||||
.describe("Data split mode for distributed training.");
|
||||
DMLC_DECLARE_FIELD(disable_default_eval_metric)
|
||||
.set_default(false)
|
||||
.describe("Flag to disable default metric. Set to >0 to disable");
|
||||
@@ -445,12 +436,6 @@ class LearnerConfiguration : public Learner {
|
||||
|
||||
ConsoleLogger::Configure(args);
|
||||
|
||||
// add additional parameters
|
||||
// These are cosntraints that need to be satisfied.
|
||||
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
|
||||
tparam_.dsplit = DataSplitMode::kRow;
|
||||
}
|
||||
|
||||
// set seed only before the model is initialized
|
||||
if (!initialized || ctx_.seed != old_seed) {
|
||||
common::GlobalRandom().seed(ctx_.seed);
|
||||
@@ -1055,11 +1040,6 @@ class LearnerIO : public LearnerConfiguration {
|
||||
auto n = tparam_.__DICT__();
|
||||
cfg_.insert(n.cbegin(), n.cend());
|
||||
|
||||
// copy dsplit from config since it will not run again during restore
|
||||
if (tparam_.dsplit == DataSplitMode::kAuto && collective::IsDistributed()) {
|
||||
tparam_.dsplit = DataSplitMode::kRow;
|
||||
}
|
||||
|
||||
this->need_configuration_ = true;
|
||||
}
|
||||
|
||||
@@ -1199,16 +1179,6 @@ class LearnerImpl : public LearnerIO {
|
||||
local_map->erase(this);
|
||||
}
|
||||
}
|
||||
// Configuration before data is known.
|
||||
void CheckDataSplitMode() {
|
||||
if (collective::IsDistributed()) {
|
||||
CHECK(tparam_.dsplit != DataSplitMode::kAuto)
|
||||
<< "Precondition violated; dsplit cannot be 'auto' in distributed mode";
|
||||
if (tparam_.dsplit == DataSplitMode::kCol) {
|
||||
LOG(FATAL) << "Column-wise data split is currently not supported.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
|
||||
std::string format) override {
|
||||
@@ -1266,7 +1236,6 @@ class LearnerImpl : public LearnerIO {
|
||||
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
|
||||
this->CheckDataSplitMode();
|
||||
this->ValidateDMatrix(train.get(), true);
|
||||
|
||||
auto local_cache = this->GetPredictionCache();
|
||||
@@ -1295,7 +1264,6 @@ class LearnerImpl : public LearnerIO {
|
||||
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
|
||||
this->CheckDataSplitMode();
|
||||
this->ValidateDMatrix(train.get(), true);
|
||||
|
||||
auto local_cache = this->GetPredictionCache();
|
||||
@@ -1444,19 +1412,14 @@ class LearnerImpl : public LearnerIO {
|
||||
MetaInfo const& info = p_fmat->Info();
|
||||
info.Validate(ctx_.gpu_id);
|
||||
|
||||
auto const row_based_split = [this]() {
|
||||
return tparam_.dsplit == DataSplitMode::kRow || tparam_.dsplit == DataSplitMode::kAuto;
|
||||
};
|
||||
if (row_based_split()) {
|
||||
if (is_training) {
|
||||
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
|
||||
<< "Number of columns does not match number of features in "
|
||||
"booster.";
|
||||
} else {
|
||||
CHECK_GE(learner_model_param_.num_feature, p_fmat->Info().num_col_)
|
||||
<< "Number of columns does not match number of features in "
|
||||
"booster.";
|
||||
}
|
||||
if (is_training) {
|
||||
CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
|
||||
<< "Number of columns does not match number of features in "
|
||||
"booster.";
|
||||
} else {
|
||||
CHECK_GE(learner_model_param_.num_feature, p_fmat->Info().num_col_)
|
||||
<< "Number of columns does not match number of features in "
|
||||
"booster.";
|
||||
}
|
||||
|
||||
if (p_fmat->Info().num_row_ == 0) {
|
||||
|
||||
Reference in New Issue
Block a user