Support vertical federated learning with gpu_hist (#9539)
This commit is contained in:
@@ -55,11 +55,11 @@ void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
} // namespace cpu_impl
|
||||
|
||||
namespace cuda_impl {
|
||||
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
|
||||
linalg::VectorView<float> out);
|
||||
void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out);
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
|
||||
inline void FitStump(Context const*, MetaInfo const&, linalg::TensorView<GradientPair const, 2>,
|
||||
linalg::VectorView<float>) {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
@@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientP
|
||||
gpair.SetDevice(ctx->Device());
|
||||
auto gpair_t = gpair.View(ctx->Device());
|
||||
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
|
||||
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->Device()));
|
||||
: cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device()));
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
|
||||
#include "../collective/aggregator.cuh"
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
|
||||
#include "fit_stump.h"
|
||||
@@ -23,8 +24,8 @@
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace cuda_impl {
|
||||
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
|
||||
linalg::VectorView<float> out) {
|
||||
void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out) {
|
||||
auto n_targets = out.Size();
|
||||
CHECK_EQ(n_targets, gpair.Shape(1));
|
||||
linalg::Vector<GradientPairPrecise> sum = linalg::Constant(ctx, GradientPairPrecise{}, n_targets);
|
||||
@@ -49,8 +50,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
|
||||
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
|
||||
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
|
||||
|
||||
collective::AllReduce<collective::Operation::kSum>(
|
||||
ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
|
||||
collective::GlobalSum(info, ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()),
|
||||
d_sum.Size() * 2);
|
||||
|
||||
thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
|
||||
[=] XGBOOST_DEVICE(std::size_t i) mutable {
|
||||
|
||||
Reference in New Issue
Block a user