Support vertical federated learning with gpu_hist (#9539)

This commit is contained in:
Rong Ou
2023-09-02 20:37:11 -07:00
committed by GitHub
parent 9bab06cbca
commit c928dd4ff5
6 changed files with 113 additions and 57 deletions

View File

@@ -55,11 +55,11 @@ void FitStump(Context const* ctx, MetaInfo const& info,
} // namespace cpu_impl
namespace cuda_impl {
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
linalg::VectorView<float> out);
void FitStump(Context const* ctx, MetaInfo const& info,
linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out);
#if !defined(XGBOOST_USE_CUDA)
inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
inline void FitStump(Context const*, MetaInfo const&, linalg::TensorView<GradientPair const, 2>,
linalg::VectorView<float>) {
common::AssertGPUSupport();
}
@@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientP
gpair.SetDevice(ctx->Device());
auto gpair_t = gpair.View(ctx->Device());
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->Device()));
: cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device()));
}
} // namespace tree
} // namespace xgboost

View File

@@ -11,6 +11,7 @@
#include <cstddef> // std::size_t
#include "../collective/aggregator.cuh"
#include "../collective/communicator-inl.cuh"
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
#include "fit_stump.h"
@@ -23,8 +24,8 @@
namespace xgboost {
namespace tree {
namespace cuda_impl {
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
linalg::VectorView<float> out) {
void FitStump(Context const* ctx, MetaInfo const& info,
linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out) {
auto n_targets = out.Size();
CHECK_EQ(n_targets, gpair.Shape(1));
linalg::Vector<GradientPairPrecise> sum = linalg::Constant(ctx, GradientPairPrecise{}, n_targets);
@@ -49,8 +50,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
collective::AllReduce<collective::Operation::kSum>(
ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
collective::GlobalSum(info, ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()),
d_sum.Size() * 2);
thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
[=] XGBOOST_DEVICE(std::size_t i) mutable {