From c928dd4ff500743fdf0ffd335464ccbf360f285d Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Sat, 2 Sep 2023 20:37:11 -0700
Subject: [PATCH] Support vertical federated learning with `gpu_hist` (#9539)

---
 src/collective/aggregator.h                | 47 ++++++++++++++++++-
 src/learner.cc                             |  6 +--
 src/objective/adaptive.cu                  | 54 +++++++++++-----------
 src/tree/fit_stump.cc                      |  8 ++--
 src/tree/fit_stump.cu                      |  9 ++--
 tests/cpp/plugin/test_federated_learner.cc | 46 ++++++++++--------
 6 files changed, 113 insertions(+), 57 deletions(-)

diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h
index b33ca28ef..f2a9ff528 100644
--- a/src/collective/aggregator.h
+++ b/src/collective/aggregator.h
@@ -26,7 +26,6 @@ namespace collective {
  * applied there, with the results broadcast to other workers.
  *
  * @tparam Function The function used to calculate the results.
- * @tparam Args Arguments to the function.
  * @param info MetaInfo about the DMatrix.
  * @param buffer The buffer storing the results.
  * @param size The size of the buffer.
@@ -57,6 +56,52 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
   }
 }
 
+/**
+ * @brief Apply the given function where the labels are.
+ *
+ * Normally all the workers have access to the labels, so the function is just applied locally. In
+ * vertical federated learning, we assume labels are only available on worker 0, so the function is
+ * applied there, with the results broadcast to other workers.
+ *
+ * @tparam T Type of the HostDeviceVector storing the results.
+ * @tparam Function The function used to calculate the results.
+ * @param info MetaInfo about the DMatrix.
+ * @param result The HostDeviceVector storing the results.
+ * @param function The function used to calculate the results.
+ */
+template <typename T, typename Function>
+void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
+  if (info.IsVerticalFederated()) {
+    // We assume labels are only available on worker 0, so the calculation is done there and result
+    // broadcast to other workers.
+    std::string message;
+    if (collective::GetRank() == 0) {
+      try {
+        std::forward<Function>(function)();
+      } catch (dmlc::Error& e) {
+        message = e.what();
+      }
+    }
+
+    collective::Broadcast(&message, 0);
+    if (!message.empty()) {
+      LOG(FATAL) << &message[0];
+      return;
+    }
+
+    std::size_t size{};
+    if (collective::GetRank() == 0) {
+      size = result->Size();
+    }
+    collective::Broadcast(&size, sizeof(std::size_t), 0);
+
+    result->Resize(size);
+    collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
+  } else {
+    std::forward<Function>(function)();
+  }
+}
+
 /**
  * @brief Find the global max of the given value across all workers.
  *
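To make the new overload's contract concrete, here is a minimal usage sketch, written as if it lived inside the XGBoost source tree. `ExampleStat` and the statistic it computes are hypothetical; `ApplyWithLabels`, `HostDeviceVector`, `MetaInfo`, and `labels.HostView()` come from the patch and the library:

    #include "xgboost/data.h"                // xgboost::MetaInfo
    #include "xgboost/host_device_vector.h"  // xgboost::HostDeviceVector
    #include "../collective/aggregator.h"    // collective::ApplyWithLabels

    namespace xgboost {
    // Hypothetical helper: derive a statistic from the labels on worker 0 and
    // receive the broadcast result on every other worker.
    void ExampleStat(MetaInfo const& info, HostDeviceVector<float>* out) {
      collective::ApplyWithLabels(info, out, [&] {
        // Under vertical federated learning this lambda runs on rank 0 only;
        // a dmlc::Error thrown here is rebroadcast and raised on all workers.
        auto h_labels = info.labels.HostView();
        out->Resize(1);
        out->HostVector()[0] = h_labels.Size() > 0 ? h_labels(0, 0) : 0.0f;
      });
      // On return every worker holds the same one-element result.
    }
    }  // namespace xgboost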
diff --git a/src/learner.cc b/src/learner.cc
index 33725b612..79dca44bd 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -847,8 +847,7 @@ class LearnerConfiguration : public Learner {
 
   void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
     base_score->Reshape(1);
-    collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
-                                sizeof(bst_float) * base_score->Size(),
+    collective::ApplyWithLabels(info, base_score->Data(),
                                 [&] { UsePtr(obj_)->InitEstimation(info, base_score); });
   }
 };
@@ -1467,8 +1466,7 @@ class LearnerImpl : public LearnerIO {
 
   void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
                    std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
     out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
-    collective::ApplyWithLabels(info, out_gpair->Data()->HostPointer(),
-                                out_gpair->Size() * sizeof(GradientPair),
+    collective::ApplyWithLabels(info, out_gpair->Data(),
                                 [&] { obj_->GetGradient(preds, info, iter, out_gpair); });
   }
 
diff --git a/src/objective/adaptive.cu b/src/objective/adaptive.cu
index 29f70a8d8..cea211622 100644
--- a/src/objective/adaptive.cu
+++ b/src/objective/adaptive.cu
@@ -6,6 +6,7 @@
 #include <cstdint>  // std::int32_t
 #include <vector>   // NOLINT
 
+#include "../collective/aggregator.h"
 #include "../common/cuda_context.cuh"  // CUDAContext
 #include "../common/device_helpers.cuh"
 #include "../common/stats.cuh"
@@ -154,38 +155,39 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
     UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
   }
 
-  HostDeviceVector<float> quantiles;
   predt.SetDevice(ctx->Device());
-
   auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_,
                                         predt.Size() / info.num_row_);
   CHECK_LT(group_idx, d_predt.Shape(1));
   auto t_predt = d_predt.Slice(linalg::All(), group_idx);
-  auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
-
-  auto d_row_index = dh::ToSpan(ridx);
-  auto seg_beg = nptr.DevicePointer();
-  auto seg_end = seg_beg + nptr.Size();
-  auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
-                                                  [=] XGBOOST_DEVICE(size_t i) {
-                                                    float p = t_predt(d_row_index[i]);
-                                                    auto y = d_labels(d_row_index[i]);
-                                                    return y - p;
-                                                  });
-  CHECK_EQ(d_labels.Shape(0), position.size());
-  auto val_end = val_beg + d_labels.Shape(0);
-  CHECK_EQ(nidx.Size() + 1, nptr.Size());
-  if (info.weights_.Empty()) {
-    common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
-  } else {
-    info.weights_.SetDevice(ctx->Device());
-    auto d_weights = info.weights_.ConstDeviceSpan();
-    CHECK_EQ(d_weights.size(), d_row_index.size());
-    auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
-    common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
-                                      w_it + d_weights.size(), &quantiles);
-  }
+  HostDeviceVector<float> quantiles;
+  collective::ApplyWithLabels(info, &quantiles, [&] {
+    auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
+    auto d_row_index = dh::ToSpan(ridx);
+    auto seg_beg = nptr.DevicePointer();
+    auto seg_end = seg_beg + nptr.Size();
+    auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
+                                                    [=] XGBOOST_DEVICE(size_t i) {
+                                                      float p = t_predt(d_row_index[i]);
+                                                      auto y = d_labels(d_row_index[i]);
+                                                      return y - p;
+                                                    });
+    CHECK_EQ(d_labels.Shape(0), position.size());
+    auto val_end = val_beg + d_labels.Shape(0);
+    CHECK_EQ(nidx.Size() + 1, nptr.Size());
+    if (info.weights_.Empty()) {
+      common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
+    } else {
+      info.weights_.SetDevice(ctx->Device());
+      auto d_weights = info.weights_.ConstDeviceSpan();
+      CHECK_EQ(d_weights.size(), d_row_index.size());
+      auto w_it =
+          thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
+      common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
+                                        w_it + d_weights.size(), &quantiles);
+    }
+  });
 
   UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
 }
 }  // namespace detail
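These two call sites show why the patch adds the HostDeviceVector overload. In learner.cc the gradient buffer has a shape every worker can compute up front (num_row_ times the output length), but in adaptive.cu the number of quantiles is only known on worker 0 after the label-dependent computation has run, which is exactly the case the size-then-payload broadcast handles. A condensed, annotated fragment of the two patterns as they appear in this patch (not compilable on its own):

    // Fixed-size result: every worker can Reshape() the buffer beforehand,
    // so only the contents need to be broadcast from rank 0.
    out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
    collective::ApplyWithLabels(info, out_gpair->Data(),
                                [&] { obj_->GetGradient(preds, info, iter, out_gpair); });

    // Variable-size result: only rank 0 learns how many quantiles exist, so
    // the overload broadcasts the size first, resizes, then broadcasts data.
    HostDeviceVector<float> quantiles;
    collective::ApplyWithLabels(info, &quantiles, [&] {
      // label-dependent quantile computation, rank 0 only in vertical mode
    });
    UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);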
diff --git a/src/tree/fit_stump.cc b/src/tree/fit_stump.cc
index ec1b6fe18..a8f5e1d8e 100644
--- a/src/tree/fit_stump.cc
+++ b/src/tree/fit_stump.cc
@@ -55,11 +55,11 @@ void FitStump(Context const* ctx, MetaInfo const& info,
 }  // namespace cpu_impl
 
 namespace cuda_impl {
-void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
-              linalg::VectorView<float> out);
+void FitStump(Context const* ctx, MetaInfo const& info,
+              linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out);
 
 #if !defined(XGBOOST_USE_CUDA)
-inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
+inline void FitStump(Context const*, MetaInfo const&, linalg::TensorView<GradientPair const, 2>,
                      linalg::VectorView<float>) {
   common::AssertGPUSupport();
 }
@@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<Gradient
               linalg::Vector<float>* out) {
   out->SetDevice(ctx->Device());
   auto gpair_t = gpair.View(ctx->Device());
   ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
-               : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->Device()));
+               : cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device()));
 }
 }  // namespace tree
 }  // namespace xgboost
diff --git a/src/tree/fit_stump.cu b/src/tree/fit_stump.cu
index 40b2a0c96..f0d53bff1 100644
--- a/src/tree/fit_stump.cu
+++ b/src/tree/fit_stump.cu
@@ -11,6 +11,7 @@
 
 #include <cstddef>  // std::size_t
 
+#include "../collective/aggregator.cuh"
 #include "../collective/communicator-inl.cuh"
 #include "../common/device_helpers.cuh"  // dh::MakeTransformIterator
 #include "fit_stump.h"
@@ -23,8 +24,8 @@
 namespace xgboost {
 namespace tree {
 namespace cuda_impl {
-void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
-              linalg::VectorView<float> out) {
+void FitStump(Context const* ctx, MetaInfo const& info,
+              linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out) {
   auto n_targets = out.Size();
   CHECK_EQ(n_targets, gpair.Shape(1));
   linalg::Vector<GradientPairPrecise> sum = linalg::Constant(ctx, GradientPairPrecise{}, n_targets);
@@ -49,8 +50,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
   thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
                         thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
 
-  collective::AllReduce<collective::Operation::kSum>(
-      ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
+  collective::GlobalSum(info, ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()),
+                        d_sum.Size() * 2);
 
   thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
                      [=] XGBOOST_DEVICE(std::size_t i) mutable {
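The switch from AllReduce to GlobalSum is the substantive change here: with column-split (vertical) data, every worker already holds all rows, so a cross-worker sum of the gradient totals would count each row once per worker. The following is a plausible sketch of the helper that aggregator.cuh is assumed to provide, modeled on the raw AllReduce call removed above; it illustrates the intended semantics and is not code copied from that file:

    // Sketch: skip the reduction when data is split by column, since each
    // worker's local sum already covers every row exactly once.
    template <typename T>
    void GlobalSum(MetaInfo const& info, int device, T* values, std::size_t size) {
      if (info.IsColumnSplit()) {
        return;  // vertical split: the local sum is already the global sum
      }
      collective::AllReduce<collective::Operation::kSum>(device, values, size);
    }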
diff --git a/tests/cpp/plugin/test_federated_learner.cc b/tests/cpp/plugin/test_federated_learner.cc
index ac514d169..427bd790c 100644
--- a/tests/cpp/plugin/test_federated_learner.cc
+++ b/tests/cpp/plugin/test_federated_learner.cc
@@ -15,9 +15,11 @@ namespace xgboost {
 namespace {
-auto MakeModel(std::string tree_method, std::string objective, std::shared_ptr<DMatrix> dmat) {
+auto MakeModel(std::string tree_method, std::string device, std::string objective,
+               std::shared_ptr<DMatrix> dmat) {
   std::unique_ptr<Learner> learner{Learner::Create({dmat})};
   learner->SetParam("tree_method", tree_method);
+  learner->SetParam("device", device);
   learner->SetParam("objective", objective);
   if (objective.find("quantile") != std::string::npos) {
     learner->SetParam("quantile_alpha", "0.5");
   }
 
@@ -35,7 +37,7 @@ auto MakeModel(std::string tree_method, std::string objective, std::shared_ptr<
 }
 
 void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model,
-                     std::string tree_method, std::string objective) {
+                     std::string tree_method, std::string device, std::string objective) {
   auto rank = collective::GetRank();
   std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
 
@@ -61,14 +63,14 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e
   }
 
   std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
-  auto model = MakeModel(tree_method, objective, sliced);
+  auto model = MakeModel(tree_method, device, objective, sliced);
 
   auto base_score = GetBaseScore(model);
-  ASSERT_EQ(base_score, expected_base_score);
-  ASSERT_EQ(model, expected_model);
+  ASSERT_EQ(base_score, expected_base_score) << " rank " << rank;
+  ASSERT_EQ(model, expected_model) << " rank " << rank;
 }
 }  // namespace
 
-class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
+class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string> {
   std::unique_ptr<ServerForTest> server_;
   static int constexpr kWorldSize{3};
 
@@ -76,7 +78,7 @@ class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
   void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
   void TearDown() override { server_.reset(nullptr); }
 
-  void Run(std::string tree_method, std::string objective) {
+  void Run(std::string tree_method, std::string device, std::string objective) {
     static auto constexpr kRows{16};
     static auto constexpr kCols{16};
 
@@ -99,27 +101,35 @@ class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
       }
     }
 
-    auto model = MakeModel(tree_method, objective, dmat);
+    auto model = MakeModel(tree_method, device, objective, dmat);
     auto score = GetBaseScore(model);
     RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols,
-                                 score, model, tree_method, objective);
+                                 score, model, tree_method, device, objective);
   }
 };
 
-TEST_P(FederatedLearnerTest, Approx) {
+TEST_P(VerticalFederatedLearnerTest, Approx) {
   std::string objective = GetParam();
-  this->Run("approx", objective);
+  this->Run("approx", "cpu", objective);
 }
 
-TEST_P(FederatedLearnerTest, Hist) {
+TEST_P(VerticalFederatedLearnerTest, Hist) {
   std::string objective = GetParam();
-  this->Run("hist", objective);
+  this->Run("hist", "cpu", objective);
 }
 
-INSTANTIATE_TEST_SUITE_P(FederatedLearnerObjective, FederatedLearnerTest,
-                         ::testing::ValuesIn(MakeObjNamesForTest()),
-                         [](const ::testing::TestParamInfo<FederatedLearnerTest::ParamType> &info) {
-                           return ObjTestNameGenerator(info);
-                         });
+#if defined(XGBOOST_USE_CUDA)
+TEST_P(VerticalFederatedLearnerTest, GPUHist) {
+  std::string objective = GetParam();
+  this->Run("hist", "cuda:0", objective);
+}
+#endif  // defined(XGBOOST_USE_CUDA)
+
+INSTANTIATE_TEST_SUITE_P(
+    FederatedLearnerObjective, VerticalFederatedLearnerTest,
+    ::testing::ValuesIn(MakeObjNamesForTest()),
+    [](const ::testing::TestParamInfo<VerticalFederatedLearnerTest::ParamType> &info) {
+      return ObjTestNameGenerator(info);
+    });
 }  // namespace xgboost
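For orientation, the configuration the new GPUHist test exercises maps onto the following learner setup. This is a hedged sketch using the same C++ Learner API the test's MakeModel helper uses; TrainGpuHist, the iteration count, and the reg:squarederror objective are illustrative stand-ins:

    #include <cstdint>  // std::int32_t
    #include <memory>   // std::shared_ptr, std::unique_ptr

    #include "xgboost/learner.h"  // xgboost::Learner

    namespace xgboost {
    // Hypothetical helper: `gpu_hist` is now expressed as tree_method=hist
    // plus a CUDA device, matching what MakeModel configures above.
    std::unique_ptr<Learner> TrainGpuHist(std::shared_ptr<DMatrix> dmat) {
      std::unique_ptr<Learner> learner{Learner::Create({dmat})};
      learner->SetParam("tree_method", "hist");
      learner->SetParam("device", "cuda:0");
      learner->SetParam("objective", "reg:squarederror");  // example objective
      for (std::int32_t iter = 0; iter < 3; ++iter) {
        learner->UpdateOneIter(iter, dmat);
      }
      return learner;
    }
    }  // namespace xgboost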