Support column-wise data split with in-memory inputs (#9628)
--------- Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -635,22 +635,39 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
|
||||
}
|
||||
|
||||
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
|
||||
if (size != 0 && this->num_col_ != 0) {
|
||||
if (size != 0 && this->num_col_ != 0 && !IsColumnSplit()) {
|
||||
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
|
||||
CHECK(info);
|
||||
}
|
||||
if (!std::strcmp(key, "feature_type")) {
|
||||
feature_type_names.clear();
|
||||
auto& h_feature_types = feature_types.HostVector();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto elem = info[i];
|
||||
feature_type_names.emplace_back(elem);
|
||||
}
|
||||
if (IsColumnSplit()) {
|
||||
feature_type_names = collective::AllgatherStrings(feature_type_names);
|
||||
CHECK_EQ(feature_type_names.size(), num_col_)
|
||||
<< "Length of " << key << " must be equal to number of columns.";
|
||||
}
|
||||
auto& h_feature_types = feature_types.HostVector();
|
||||
LoadFeatureType(feature_type_names, &h_feature_types);
|
||||
} else if (!std::strcmp(key, "feature_name")) {
|
||||
feature_names.clear();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
feature_names.emplace_back(info[i]);
|
||||
if (IsColumnSplit()) {
|
||||
std::vector<std::string> local_feature_names{};
|
||||
auto const rank = collective::GetRank();
|
||||
for (std::size_t i = 0; i < size; ++i) {
|
||||
auto elem = std::to_string(rank) + "." + info[i];
|
||||
local_feature_names.emplace_back(elem);
|
||||
}
|
||||
feature_names = collective::AllgatherStrings(local_feature_names);
|
||||
CHECK_EQ(feature_names.size(), num_col_)
|
||||
<< "Length of " << key << " must be equal to number of columns.";
|
||||
} else {
|
||||
feature_names.clear();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
feature_names.emplace_back(info[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown feature info name: " << key;
|
||||
|
||||
@@ -75,7 +75,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
}
|
||||
|
||||
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
|
||||
if (info_.IsColumnSplit()) {
|
||||
if (info_.IsColumnSplit() && collective::GetWorldSize() > 1) {
|
||||
auto const cols = collective::Allgather(info_.num_col_);
|
||||
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
|
||||
if (offset == 0) {
|
||||
|
||||
Reference in New Issue
Block a user