Support vertical federated learning (#8932)

This commit is contained in:
Rong Ou
2023-03-21 23:25:26 -07:00
committed by GitHub
parent 8dc1e4b3ea
commit b240f055d3
23 changed files with 371 additions and 249 deletions

View File

@@ -703,6 +703,14 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
}
void MetaInfo::SynchronizeNumberOfColumns() {
if (collective::IsFederated() && data_split_mode == DataSplitMode::kCol) {
collective::Allreduce<collective::Operation::kSum>(&num_col_, 1);
} else {
collective::Allreduce<collective::Operation::kMax>(&num_col_, 1);
}
}
void MetaInfo::Validate(std::int32_t device) const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
@@ -870,7 +878,7 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
data::FileAdapter adapter(parser.get());
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(),
cache_file);
cache_file, data_split_mode);
} else {
data::FileIterator iter{fname, static_cast<uint32_t>(partid), static_cast<uint32_t>(npart),
file_format};
@@ -906,11 +914,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
LOG(FATAL) << "Encountered parser error:\n" << e.what();
}
/* sync up number of features after matrix loaded.
* partitioned data will fail the train/val validation check
* since partitioned data not knowing the real number of features. */
collective::Allreduce<collective::Operation::kMax>(&dmat->Info().num_col_, 1);
if (need_split && data_split_mode == DataSplitMode::kCol) {
if (!cache_file.empty()) {
LOG(FATAL) << "Column-wise data split is not support for external memory.";
@@ -920,7 +923,6 @@ DMatrix* DMatrix::Load(const std::string& uri, bool silent, DataSplitMode data_s
delete dmat;
return sliced;
} else {
dmat->Info().data_split_mode = data_split_mode;
return dmat;
}
}
@@ -957,39 +959,49 @@ template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
XGDMatrixCallbackNext *next, float missing, int32_t n_threads, std::string);
template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&) {
return new data::SimpleDMatrix(adapter, missing, nthread);
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, const std::string&,
DataSplitMode data_split_mode) {
return new data::SimpleDMatrix(adapter, missing, nthread, data_split_mode);
}
template DMatrix* DMatrix::Create<data::DenseAdapter>(data::DenseAdapter* adapter, float missing,
std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::ArrayAdapter>(data::ArrayAdapter* adapter, float missing,
std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::CSRAdapter>(data::CSRAdapter* adapter, float missing,
std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::CSCAdapter>(data::CSCAdapter* adapter, float missing,
std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::DataTableAdapter>(data::DataTableAdapter* adapter,
float missing, std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::FileAdapter>(data::FileAdapter* adapter, float missing,
std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::CSRArrayAdapter>(data::CSRArrayAdapter* adapter,
float missing, std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::CSCArrayAdapter>(data::CSCArrayAdapter* adapter,
float missing, std::int32_t nthread,
const std::string& cache_prefix);
const std::string& cache_prefix,
DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create(
data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>* adapter,
float missing, int nthread, const std::string& cache_prefix);
float missing, int nthread, const std::string& cache_prefix, DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::RecordBatchesIterAdapter>(
data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&);
data::RecordBatchesIterAdapter* adapter, float missing, int nthread, const std::string&,
DataSplitMode data_split_mode);
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
SparsePage transpose;
@@ -1051,6 +1063,13 @@ void SparsePage::SortIndices(int32_t n_threads) {
});
}
void SparsePage::Reindex(uint64_t feature_offset, int32_t n_threads) {
auto& h_data = this->data.HostVector();
common::ParallelFor(h_data.size(), n_threads, [&](auto i) {
h_data[i].index += feature_offset;
});
}
void SparsePage::SortRows(int32_t n_threads) {
auto& h_offset = this->offset.HostVector();
auto& h_data = this->data.HostVector();

View File

@@ -170,17 +170,17 @@ void MetaInfo::SetInfoFromCUDA(Context const& ctx, StringView key, Json array) {
template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix) {
const std::string& cache_prefix, DataSplitMode data_split_mode) {
CHECK_EQ(cache_prefix.size(), 0)
<< "Device memory construction is not currently supported with external "
"memory.";
return new data::SimpleDMatrix(adapter, missing, nthread);
return new data::SimpleDMatrix(adapter, missing, nthread, data_split_mode);
}
template DMatrix* DMatrix::Create<data::CudfAdapter>(
data::CudfAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix);
const std::string& cache_prefix, DataSplitMode data_split_mode);
template DMatrix* DMatrix::Create<data::CupyAdapter>(
data::CupyAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix);
const std::string& cache_prefix, DataSplitMode data_split_mode);
} // namespace xgboost

View File

@@ -190,7 +190,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
// From here on Info() has the correct data shape
Info().num_row_ = accumulated_rows;
Info().num_nonzero_ = nnz;
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
Info().SynchronizeNumberOfColumns();
CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
return f > accumulated_rows;
})) << "Something went wrong during iteration.";

View File

@@ -166,7 +166,7 @@ void IterativeDMatrix::InitFromCUDA(DataIterHandle iter_handle, float missing,
iter.Reset();
// Synchronise worker columns
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.SynchronizeNumberOfColumns();
}
BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param) {

View File

@@ -73,6 +73,19 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
return out;
}
void SimpleDMatrix::ReindexFeatures() {
if (collective::IsFederated() && info_.data_split_mode == DataSplitMode::kCol) {
std::vector<uint64_t> buffer(collective::GetWorldSize());
buffer[collective::GetRank()] = info_.num_col_;
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0);
if (offset == 0) {
return;
}
sparse_page_->Reindex(offset, ctx_.Threads());
}
}
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available.
auto begin_iter = BatchIterator<SparsePage>(
@@ -151,7 +164,8 @@ BatchSet<ExtSparsePage> SimpleDMatrix::GetExtBatches(BatchParam const&) {
}
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
DataSplitMode data_split_mode) {
this->ctx_.nthread = nthread;
std::vector<uint64_t> qids;
@@ -217,7 +231,9 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Synchronise worker columns
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.data_split_mode = data_split_mode;
ReindexFeatures();
info_.SynchronizeNumberOfColumns();
if (adapter->NumRows() == kAdapterUnknownSize) {
using IteratorAdapterT
@@ -272,22 +288,31 @@ void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
fo->Write(sparse_page_->data.HostVector());
}
template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, int nthread);
template SimpleDMatrix::SimpleDMatrix(ArrayAdapter* adapter, float missing, int nthread);
template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing, int nthread);
template SimpleDMatrix::SimpleDMatrix(CSRArrayAdapter* adapter, float missing, int nthread);
template SimpleDMatrix::SimpleDMatrix(CSCArrayAdapter* adapter, float missing, int nthread);
template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing, int nthread);
template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing, int nthread);
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread);
template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, int nthread,
DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(ArrayAdapter* adapter, float missing, int nthread,
DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing, int nthread,
DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(CSRArrayAdapter* adapter, float missing, int nthread,
DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(CSCArrayAdapter* adapter, float missing, int nthread,
DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(CSCAdapter* adapter, float missing, int nthread,
DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(DataTableAdapter* adapter, float missing, int nthread,
DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(FileAdapter* adapter, float missing, int nthread,
DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(
IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext, XGBoostBatchCSR>
*adapter,
float missing, int nthread);
float missing, int nthread, DataSplitMode data_split_mode);
template <>
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread) {
ctx_.nthread = nthread;
SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, int nthread,
DataSplitMode data_split_mode) {
ctx_.nthread = nthread;
auto& offset_vec = sparse_page_->offset.HostVector();
auto& data_vec = sparse_page_->data.HostVector();
@@ -346,7 +371,10 @@ SimpleDMatrix::SimpleDMatrix(RecordBatchesIterAdapter* adapter, float missing, i
}
// Synchronise worker columns
info_.num_col_ = adapter->NumColumns();
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.data_split_mode = data_split_mode;
ReindexFeatures();
info_.SynchronizeNumberOfColumns();
info_.num_row_ = total_batch_size;
info_.num_nonzero_ = data_vec.size();
CHECK_EQ(offset_vec.back(), info_.num_nonzero_);

View File

@@ -15,7 +15,10 @@ namespace data {
// Current implementation assumes a single batch. More batches can
// be supported in future. Does not currently support inferring row/column size
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread*/) {
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread*/,
DataSplitMode data_split_mode) {
CHECK(data_split_mode != DataSplitMode::kCol)
<< "Column-wise data split is currently not supported on the GPU.";
auto device = (adapter->DeviceIdx() < 0 || adapter->NumRows() == 0) ? dh::CurrentDevice()
: adapter->DeviceIdx();
CHECK_GE(device, 0);
@@ -35,12 +38,13 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
info_.num_col_ = adapter->NumColumns();
info_.num_row_ = adapter->NumRows();
// Synchronise worker columns
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.data_split_mode = data_split_mode;
info_.SynchronizeNumberOfColumns();
}
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,
int nthread);
int nthread, DataSplitMode data_split_mode);
template SimpleDMatrix::SimpleDMatrix(CupyAdapter* adapter, float missing,
int nthread);
int nthread, DataSplitMode data_split_mode);
} // namespace data
} // namespace xgboost

View File

@@ -22,7 +22,8 @@ class SimpleDMatrix : public DMatrix {
public:
SimpleDMatrix() = default;
template <typename AdapterT>
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
DataSplitMode data_split_mode = DataSplitMode::kRow);
explicit SimpleDMatrix(dmlc::Stream* in_stream);
~SimpleDMatrix() override = default;
@@ -61,6 +62,15 @@ class SimpleDMatrix : public DMatrix {
bool GHistIndexExists() const override { return static_cast<bool>(gradient_index_); }
bool SparsePageExists() const override { return true; }
/**
* \brief Reindex the features based on a global view.
*
* In some cases (e.g. vertical federated learning), features are loaded locally with indices
* starting from 0. However, all the algorithms assume the features are globally indexed, so we
* reindex the features based on the offset needed to obtain the global view.
*/
void ReindexFeatures();
private:
Context ctx_;
};

View File

@@ -96,7 +96,7 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
this->info_.num_col_ = n_features;
this->info_.num_nonzero_ = nnz;
collective::Allreduce<collective::Operation::kMax>(&info_.num_col_, 1);
info_.SynchronizeNumberOfColumns();
CHECK_NE(info_.num_col_, 0);
}

View File

@@ -440,7 +440,7 @@ class LearnerConfiguration : public Learner {
info.Validate(Ctx()->gpu_id);
// We estimate it from input data.
linalg::Tensor<float, 1> base_score;
UsePtr(obj_)->InitEstimation(info, &base_score);
InitEstimation(info, &base_score);
CHECK_EQ(base_score.Size(), 1);
mparam_.base_score = base_score(0);
CHECK(!std::isnan(mparam_.base_score));
@@ -857,6 +857,25 @@ class LearnerConfiguration : public Learner {
mparam_.num_target = n_targets;
}
}
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
// Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
// We assume labels are only available on worker 0, so the estimation is calculated there
// and added to other workers.
if (collective::GetRank() == 0) {
UsePtr(obj_)->InitEstimation(info, base_score);
collective::Broadcast(base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(), 0);
} else {
base_score->Reshape(1);
collective::Broadcast(base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(), 0);
}
} else {
UsePtr(obj_)->InitEstimation(info, base_score);
}
}
};
std::string const LearnerConfiguration::kEvalMetric {"eval_metric"}; // NOLINT
@@ -1307,7 +1326,7 @@ class LearnerImpl : public LearnerIO {
monitor_.Stop("PredictRaw");
monitor_.Start("GetGradient");
obj_->GetGradient(predt.predictions, train->Info(), iter, &gpair_);
GetGradient(predt.predictions, train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient");
TrainingObserver::Instance().Observe(gpair_, "Gradients");
@@ -1486,6 +1505,28 @@ class LearnerImpl : public LearnerIO {
}
private:
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
HostDeviceVector<GradientPair>* out_gpair) {
// Special handling for vertical federated learning.
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
// We assume labels are only available on worker 0, so the gradients are calculated there
// and broadcast to other workers.
if (collective::GetRank() == 0) {
obj_->GetGradient(preds, info, iteration, out_gpair);
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
0);
} else {
CHECK_EQ(info.labels.Size(), 0)
<< "In vertical federated learning, labels should only be on the first worker";
out_gpair->Resize(preds.Size());
collective::Broadcast(out_gpair->HostPointer(), out_gpair->Size() * sizeof(GradientPair),
0);
}
} else {
obj_->GetGradient(preds, info, iteration, out_gpair);
}
}
/*! \brief random number transformation seed. */
static int32_t constexpr kRandSeedMagic = 127;
// gradient pairs

View File

@@ -33,7 +33,7 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* b
new_obj->GetGradient(dummy_predt, info, 0, &gpair);
bst_target_t n_targets = this->Targets(info);
linalg::Vector<float> leaf_weight;
tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
tree::FitStump(this->ctx_, info, gpair, n_targets, &leaf_weight);
// workaround, we don't support multi-target due to binary model serialization for
// base margin.

View File

@@ -21,7 +21,8 @@
namespace xgboost {
namespace tree {
namespace cpu_impl {
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
void FitStump(Context const* ctx, MetaInfo const& info,
linalg::TensorView<GradientPair const, 2> gpair,
linalg::VectorView<float> out) {
auto n_targets = out.Size();
CHECK_EQ(n_targets, gpair.Shape(1));
@@ -43,8 +44,12 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
}
}
CHECK(h_sum.CContiguous());
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
// In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
}
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
@@ -64,7 +69,7 @@ inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace cuda_impl
void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out) {
out->SetDevice(ctx->gpu_id);
out->Reshape(n_targets);
@@ -72,7 +77,7 @@ void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
gpair.SetDevice(ctx->gpu_id);
auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets);
ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView())
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
}
} // namespace tree

View File

@@ -16,6 +16,7 @@
#include "../common/common.h" // AssertGPUSupport
#include "xgboost/base.h" // GradientPair
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/linalg.h" // TensorView
@@ -30,7 +31,7 @@ XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) {
/**
* @brief Fit a tree stump as an estimation of base_score.
*/
void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out);
} // namespace tree
} // namespace xgboost