Support column-wise data split with in-memory inputs (#9628)

---------

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
Rong Ou
2023-10-16 21:16:39 -07:00
committed by GitHub
parent 4d1607eefd
commit da6803b75b
12 changed files with 307 additions and 27 deletions

View File

@@ -445,8 +445,11 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
API_END();
}
@@ -481,8 +484,11 @@ XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char
auto config = Json::Load(StringView{c_json_config});
float missing = GetMissing(config);
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
auto data_split_mode =
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
xgboost_CHECK_C_ARG_PTR(out);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
*out = new std::shared_ptr<DMatrix>(
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
API_END();
}
@@ -722,6 +728,15 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulon
API_END();
}
XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
auto p_m = CastDMatrixHandle(handle);
xgboost_CHECK_C_ARG_PTR(out);
*out = static_cast<xgboost::bst_ulong>(p_m->Info().data_split_mode);
API_END();
}
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
xgboost::bst_ulong *out_indptr, unsigned *out_indices,
float *out_data) {

View File

@@ -61,6 +61,7 @@ class RabitCommunicator : public Communicator {
auto const total_size = per_rank * GetWorldSize();
auto const index = per_rank * GetRank();
std::string result(total_size, '\0');
result.replace(index, per_rank, input);
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
return result;
}
@@ -71,7 +72,8 @@ class RabitCommunicator : public Communicator {
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
auto const begin_index =
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
auto const size_prev_slice = GetRank() == 0 ? 0 : all_sizes[GetRank() - 1];
auto const size_prev_slice =
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
std::string result(total_size, '\0');
result.replace(begin_index, size_node_slice, input);

View File

@@ -635,22 +635,39 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
}
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
if (size != 0 && this->num_col_ != 0) {
if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) {
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
CHECK(info);
}
if (!std::strcmp(key, "feature_type")) {
feature_type_names.clear();
auto& h_feature_types = feature_types.HostVector();
for (size_t i = 0; i < size; ++i) {
auto elem = info[i];
feature_type_names.emplace_back(elem);
}
if (IsColumnSplit()) {
feature_type_names = collective::AllgatherStrings(feature_type_names);
CHECK_EQ(feature_type_names.size(), num_col_)
<< "Length of " << key << " must be equal to number of columns.";
}
auto& h_feature_types = feature_types.HostVector();
LoadFeatureType(feature_type_names, &h_feature_types);
} else if (!std::strcmp(key, "feature_name")) {
feature_names.clear();
for (size_t i = 0; i < size; ++i) {
feature_names.emplace_back(info[i]);
if (IsColumnSplit()) {
std::vector<std::string> local_feature_names{};
auto const rank = collective::GetRank();
for (std::size_t i = 0; i < size; ++i) {
auto elem = std::to_string(rank) + "." + info[i];
local_feature_names.emplace_back(elem);
}
feature_names = collective::AllgatherStrings(local_feature_names);
CHECK_EQ(feature_names.size(), num_col_)
<< "Length of " << key << " must be equal to number of columns.";
} else {
feature_names.clear();
for (size_t i = 0; i < size; ++i) {
feature_names.emplace_back(info[i]);
}
}
} else {
LOG(FATAL) << "Unknown feature info name: " << key;

View File

@@ -75,7 +75,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
}
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
if (info_.IsColumnSplit()) {
if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) {
auto const cols = collective::Allgather(info_.num_col_);
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
if (offset == 0) {