Support vertical federated learning (#8932)
This commit is contained in:
@@ -21,7 +21,8 @@
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace cpu_impl {
|
||||
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
|
||||
void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<GradientPair const, 2> gpair,
|
||||
linalg::VectorView<float> out) {
|
||||
auto n_targets = out.Size();
|
||||
CHECK_EQ(n_targets, gpair.Shape(1));
|
||||
@@ -43,8 +44,12 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
|
||||
}
|
||||
}
|
||||
CHECK(h_sum.CContiguous());
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
|
||||
|
||||
// In vertical federated learning, only worker 0 needs to call this, no need to do an allreduce.
|
||||
if (!collective::IsFederated() || info.data_split_mode != DataSplitMode::kCol) {
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
|
||||
}
|
||||
|
||||
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
|
||||
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
|
||||
@@ -64,7 +69,7 @@ inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace cuda_impl
|
||||
|
||||
void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
|
||||
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
|
||||
bst_target_t n_targets, linalg::Vector<float>* out) {
|
||||
out->SetDevice(ctx->gpu_id);
|
||||
out->Reshape(n_targets);
|
||||
@@ -72,7 +77,7 @@ void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
|
||||
|
||||
gpair.SetDevice(ctx->gpu_id);
|
||||
auto gpair_t = linalg::MakeTensorView(ctx, &gpair, n_samples, n_targets);
|
||||
ctx->IsCPU() ? cpu_impl::FitStump(ctx, gpair_t, out->HostView())
|
||||
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
|
||||
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->gpu_id));
|
||||
}
|
||||
} // namespace tree
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "../common/common.h" // AssertGPUSupport
|
||||
#include "xgboost/base.h" // GradientPair
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/data.h" // MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/linalg.h" // TensorView
|
||||
|
||||
@@ -30,7 +31,7 @@ XGBOOST_DEVICE inline double CalcUnregularizedWeight(T sum_grad, T sum_hess) {
|
||||
/**
|
||||
* @brief Fit a tree stump as an estimation of base_score.
|
||||
*/
|
||||
void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,
|
||||
void FitStump(Context const* ctx, MetaInfo const& info, HostDeviceVector<GradientPair> const& gpair,
|
||||
bst_target_t n_targets, linalg::Vector<float>* out);
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user