Support column-wise data split with in-memory inputs (#9628)
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -445,8 +445,11 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr, char const *indices, char
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", 0);
|
||||
auto data_split_mode =
|
||||
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
*out = new std::shared_ptr<DMatrix>(
|
||||
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -481,8 +484,11 @@ XGB_DLL int XGDMatrixCreateFromCSC(char const *indptr, char const *indices, char
|
||||
auto config = Json::Load(StringView{c_json_config});
|
||||
float missing = GetMissing(config);
|
||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||
auto data_split_mode =
|
||||
static_cast<DataSplitMode>(OptionalArg<Integer, int64_t>(config, "data_split_mode", 0));
|
||||
xgboost_CHECK_C_ARG_PTR(out);
|
||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||
*out = new std::shared_ptr<DMatrix>(
|
||||
DMatrix::Create(&adapter, missing, n_threads, "", data_split_mode));
|
||||
|
||||
API_END();
|
||||
}
|
||||
@@ -722,6 +728,15 @@ XGB_DLL int XGDMatrixNumNonMissing(DMatrixHandle const handle, xgboost::bst_ulon
|
||||
API_END();
|
||||
}
|
||||
|
||||
/**
 * C API entry point: report the data split mode stored in a DMatrix's meta info.
 *
 * @param handle Handle of the DMatrix to query; validated by CHECK_HANDLE().
 * @param out    Output pointer; receives `Info().data_split_mode` cast to bst_ulong.
 * @return       int status; the value is supplied by the API_BEGIN()/API_END()
 *               macros, whose definitions are not visible in this excerpt.
 */
XGB_DLL int XGDMatrixDataSplitMode(DMatrixHandle handle, bst_ulong *out) {
  API_BEGIN();
  CHECK_HANDLE();
  auto dmat = CastDMatrixHandle(handle);
  // The output pointer is checked only after the handle has been resolved.
  xgboost_CHECK_C_ARG_PTR(out);
  auto const &meta = dmat->Info();
  *out = static_cast<xgboost::bst_ulong>(meta.data_split_mode);
  API_END();
}
|
||||
|
||||
XGB_DLL int XGDMatrixGetDataAsCSR(DMatrixHandle const handle, char const *config,
|
||||
xgboost::bst_ulong *out_indptr, unsigned *out_indices,
|
||||
float *out_data) {
|
||||
|
||||
@@ -61,6 +61,7 @@ class RabitCommunicator : public Communicator {
|
||||
auto const total_size = per_rank * GetWorldSize();
|
||||
auto const index = per_rank * GetRank();
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(index, per_rank, input);
|
||||
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
|
||||
return result;
|
||||
}
|
||||
@@ -71,7 +72,8 @@ class RabitCommunicator : public Communicator {
|
||||
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
|
||||
auto const begin_index =
|
||||
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
|
||||
auto const size_prev_slice = GetRank() == 0 ? 0 : all_sizes[GetRank() - 1];
|
||||
auto const size_prev_slice =
|
||||
GetRank() == 0 ? all_sizes[GetWorldSize() - 1] : all_sizes[GetRank() - 1];
|
||||
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(begin_index, size_node_slice, input);
|
||||
|
||||
@@ -635,22 +635,39 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
|
||||
}
|
||||
|
||||
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
|
||||
if (size != 0 && this->num_col_ != 0) {
|
||||
if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) {
|
||||
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
|
||||
CHECK(info);
|
||||
}
|
||||
if (!std::strcmp(key, "feature_type")) {
|
||||
feature_type_names.clear();
|
||||
auto& h_feature_types = feature_types.HostVector();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto elem = info[i];
|
||||
feature_type_names.emplace_back(elem);
|
||||
}
|
||||
if (IsColumnSplit()) {
|
||||
feature_type_names = collective::AllgatherStrings(feature_type_names);
|
||||
CHECK_EQ(feature_type_names.size(), num_col_)
|
||||
<< "Length of " << key << " must be equal to number of columns.";
|
||||
}
|
||||
auto& h_feature_types = feature_types.HostVector();
|
||||
LoadFeatureType(feature_type_names, &h_feature_types);
|
||||
} else if (!std::strcmp(key, "feature_name")) {
|
||||
feature_names.clear();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
feature_names.emplace_back(info[i]);
|
||||
if (IsColumnSplit()) {
|
||||
std::vector<std::string> local_feature_names{};
|
||||
auto const rank = collective::GetRank();
|
||||
for (std::size_t i = 0; i < size; ++i) {
|
||||
auto elem = std::to_string(rank) + "." + info[i];
|
||||
local_feature_names.emplace_back(elem);
|
||||
}
|
||||
feature_names = collective::AllgatherStrings(local_feature_names);
|
||||
CHECK_EQ(feature_names.size(), num_col_)
|
||||
<< "Length of " << key << " must be equal to number of columns.";
|
||||
} else {
|
||||
feature_names.clear();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
feature_names.emplace_back(info[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown feature info name: " << key;
|
||||
|
||||
@@ -75,7 +75,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
}
|
||||
|
||||
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
|
||||
if (info_.IsColumnSplit()) {
|
||||
if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) {
|
||||
auto const cols = collective::Allgather(info_.num_col_);
|
||||
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
|
||||
if (offset == 0) {
|
||||
|
||||
Reference in New Issue
Block a user