sync Mar 29

This commit is contained in:
amdsc21
2023-03-30 00:46:50 +02:00
20 changed files with 335 additions and 115 deletions

View File

@@ -704,7 +704,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
void MetaInfo::SynchronizeNumberOfColumns() {
if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
if (IsVerticalFederated()) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
@@ -770,6 +770,10 @@ void MetaInfo::Validate(std::int32_t device) const {
void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); }
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
bool MetaInfo::IsVerticalFederated() const {
return collective::IsFederated() && IsColumnSplit();
}
using DMatrixThreadLocal =
dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>;

View File

@@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
proxy->IsColumnSplit(), ctx_.Threads()});
proxy->Info().IsColumnSplit(), ctx_.Threads()});
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];

View File

@@ -74,7 +74,7 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
}
void SimpleDMatrix::ReindexFeatures() {
if (collective::IsFederated() && info_.data_split_mode == DataSplitMode::kCol) {
if (info_.IsVerticalFederated()) {
std::vector<uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_;
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));