Support vertical federated learning (#8932)

This commit is contained in:
Rong Ou
2023-03-21 23:25:26 -07:00
committed by GitHub
parent 8dc1e4b3ea
commit b240f055d3
23 changed files with 371 additions and 249 deletions

View File

@@ -21,7 +21,8 @@
namespace xgboost {
namespace tree {
namespace cpu_impl {
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
void FitStump(Context const* ctx, MetaInfo const& info,
linalg::TensorView<GradientPair const, 2> gpair,
linalg::VectorView<float> out) {
auto n_targets = out.Size();
CHECK_EQ(n_targets, gpair.Shape(1));
@@ -43,8 +44,12 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
}
}
CHECK(h_sum.CContiguous());
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
// In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
}
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
@@ -64,7 +69,7 @@ inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace cuda_impl
void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out) {
out->SetDevice(ctx->gpu_id);
out->Reshape(n_targets);
@@ -72,7 +77,7 @@ void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
gpair.SetDevice(ctx->gpu_id);
auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets);
ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView())
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
}
} // namespace tree

View File

@@ -16,6 +16,7 @@
#include "../common/common.h" // AssertGPUSupport
#include "xgboost/base.h" // GradientPair
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/linalg.h" // TensorView
@@ -30,7 +31,7 @@ XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) {
/**
* @brief Fit a tree stump as an estimation of base_score.
*/
void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
bst_target_t n_targets, linalg::Vector<float>* out);
} // namespace tree
} // namespace xgboost