Use DataSplitMode to configure data loading (#8434)

* Use `DataSplitMode` to configure data loading
This commit is contained in:
Rong Ou
2022-11-08 00:21:50 -08:00
committed by GitHub
parent 0d3da9869c
commit 8e76f5f595
13 changed files with 46 additions and 40 deletions

View File

@@ -777,8 +777,10 @@ DMatrix *TryLoadBinary(std::string fname, bool silent) {
return nullptr;
}
DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_split_mode,
const std::string& file_format) {
CHECK(data_split_mode == DataSplitMode::kRow || data_split_mode == DataSplitMode::kNone)
<< "Precondition violated; data split mode can only be 'row' or 'none'";
std::string fname, cache_file;
size_t dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) {
@@ -786,7 +788,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
fname = uri.substr(0, dlm_pos);
CHECK_EQ(cache_file.find('#'), std::string::npos)
<< "Only one `#` is allowed in file path for cache file specification.";
if (load_row_split) {
if (data_split_mode == DataSplitMode::kRow) {
std::ostringstream os;
std::vector<std::string> cache_shards = common::Split(cache_file, ':');
for (size_t i = 0; i < cache_shards.size(); ++i) {
@@ -820,7 +822,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split,
}
int partid = 0, npart = 1;
if (load_row_split) {
if (data_split_mode == DataSplitMode::kRow) {
partid = collective::GetRank();
npart = collective::GetWorldSize();
} else {