Support vertical federated learning with gpu_hist (#9539)

This commit is contained in:
Rong Ou 2023-09-02 20:37:11 -07:00 committed by GitHub
parent 9bab06cbca
commit c928dd4ff5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 113 additions and 57 deletions

View File

@ -26,7 +26,6 @@ namespace collective {
* applied there, with the results broadcast to other workers. * applied there, with the results broadcast to other workers.
* *
* @tparam Function The function used to calculate the results. * @tparam Function The function used to calculate the results.
* @tparam Args Arguments to the function.
* @param info MetaInfo about the DMatrix. * @param info MetaInfo about the DMatrix.
* @param buffer The buffer storing the results. * @param buffer The buffer storing the results.
* @param size The size of the buffer. * @param size The size of the buffer.
@ -57,6 +56,52 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
} }
} }
/**
* @brief Apply the given function where the labels are.
*
* Normally all the workers have access to the labels, so the function is just applied locally. In
* vertical federated learning, we assume labels are only available on worker 0, so the function is
* applied there, with the results broadcast to other workers.
*
* @tparam T Type of the HostDeviceVector storing the results.
* @tparam Function The function used to calculate the results.
* @param info MetaInfo about the DMatrix.
* @param result The HostDeviceVector storing the results.
* @param function The function used to calculate the results.
*/
template <typename T, typename Function>
void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
}
collective::Broadcast(&message, 0);
if (!message.empty()) {
LOG(FATAL) << &message[0];
return;
}
std::size_t size{};
if (collective::GetRank() == 0) {
size = result->Size();
}
collective::Broadcast(&size, sizeof(std::size_t), 0);
result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
} else {
std::forward<Function>(function)();
}
}
/** /**
* @brief Find the global max of the given value across all workers. * @brief Find the global max of the given value across all workers.
* *

View File

@ -847,8 +847,7 @@ class LearnerConfiguration : public Learner {
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) { void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
base_score->Reshape(1); base_score->Reshape(1);
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(), collective::ApplyWithLabels(info, base_score->Data(),
sizeof(bst_float) * base_score->Size(),
[&] { UsePtr(obj_)->InitEstimation(info, base_score); }); [&] { UsePtr(obj_)->InitEstimation(info, base_score); });
} }
}; };
@ -1467,8 +1466,7 @@ class LearnerImpl : public LearnerIO {
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) { std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength()); out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
collective::ApplyWithLabels(info, out_gpair->Data()->HostPointer(), collective::ApplyWithLabels(info, out_gpair->Data(),
out_gpair->Size() * sizeof(GradientPair),
[&] { obj_->GetGradient(preds, info, iter, out_gpair); }); [&] { obj_->GetGradient(preds, info, iter, out_gpair); });
} }

View File

@ -6,6 +6,7 @@
#include <cstdint> // std::int32_t #include <cstdint> // std::int32_t
#include <cub/cub.cuh> // NOLINT #include <cub/cub.cuh> // NOLINT
#include "../collective/aggregator.h"
#include "../common/cuda_context.cuh" // CUDAContext #include "../common/cuda_context.cuh" // CUDAContext
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#include "../common/stats.cuh" #include "../common/stats.cuh"
@ -154,15 +155,15 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree); UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
} }
HostDeviceVector<float> quantiles;
predt.SetDevice(ctx->Device()); predt.SetDevice(ctx->Device());
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_, auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_,
predt.Size() / info.num_row_); predt.Size() / info.num_row_);
CHECK_LT(group_idx, d_predt.Shape(1)); CHECK_LT(group_idx, d_predt.Shape(1));
auto t_predt = d_predt.Slice(linalg::All(), group_idx); auto t_predt = d_predt.Slice(linalg::All(), group_idx);
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
HostDeviceVector<float> quantiles;
collective::ApplyWithLabels(info, &quantiles, [&] {
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
auto d_row_index = dh::ToSpan(ridx); auto d_row_index = dh::ToSpan(ridx);
auto seg_beg = nptr.DevicePointer(); auto seg_beg = nptr.DevicePointer();
auto seg_end = seg_beg + nptr.Size(); auto seg_end = seg_beg + nptr.Size();
@ -181,11 +182,12 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
info.weights_.SetDevice(ctx->Device()); info.weights_.SetDevice(ctx->Device());
auto d_weights = info.weights_.ConstDeviceSpan(); auto d_weights = info.weights_.ConstDeviceSpan();
CHECK_EQ(d_weights.size(), d_row_index.size()); CHECK_EQ(d_weights.size(), d_row_index.size());
auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index)); auto w_it =
thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it, common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
w_it + d_weights.size(), &quantiles); w_it + d_weights.size(), &quantiles);
} }
});
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree); UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
} }
} // namespace detail } // namespace detail

View File

@ -55,11 +55,11 @@ void FitStump(Context const* ctx, MetaInfo const& info,
} // namespace cpu_impl } // namespace cpu_impl
namespace cuda_impl { namespace cuda_impl {
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair, void FitStump(Context const* ctx, MetaInfo const& info,
linalg::VectorView<float> out); linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out);
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA)
inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>, inline void FitStump(Context const*, MetaInfo const&, linalg::TensorView<GradientPair const, 2>,
linalg::VectorView<float>) { linalg::VectorView<float>) {
common::AssertGPUSupport(); common::AssertGPUSupport();
} }
@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientP
gpair.SetDevice(ctx->Device()); gpair.SetDevice(ctx->Device());
auto gpair_t = gpair.View(ctx->Device()); auto gpair_t = gpair.View(ctx->Device());
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView()) ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->Device())); : cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device()));
} }
} // namespace tree } // namespace tree
} // namespace xgboost } // namespace xgboost

View File

@ -11,6 +11,7 @@
#include <cstddef> // std::size_t #include <cstddef> // std::size_t
#include "../collective/aggregator.cuh"
#include "../collective/communicator-inl.cuh" #include "../collective/communicator-inl.cuh"
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator #include "../common/device_helpers.cuh" // dh::MakeTransformIterator
#include "fit_stump.h" #include "fit_stump.h"
@ -23,8 +24,8 @@
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
namespace cuda_impl { namespace cuda_impl {
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair, void FitStump(Context const* ctx, MetaInfo const& info,
linalg::VectorView<float> out) { linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out) {
auto n_targets = out.Size(); auto n_targets = out.Size();
CHECK_EQ(n_targets, gpair.Shape(1)); CHECK_EQ(n_targets, gpair.Shape(1));
linalg::Vector<GradientPairPrecise> sum = linalg::Constant(ctx, GradientPairPrecise{}, n_targets); linalg::Vector<GradientPairPrecise> sum = linalg::Constant(ctx, GradientPairPrecise{}, n_targets);
@ -49,8 +50,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it, thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values())); thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
collective::AllReduce<collective::Operation::kSum>( collective::GlobalSum(info, ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()),
ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2); d_sum.Size() * 2);
thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets, thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
[=] XGBOOST_DEVICE(std::size_t i) mutable { [=] XGBOOST_DEVICE(std::size_t i) mutable {

View File

@ -15,9 +15,11 @@
namespace xgboost { namespace xgboost {
namespace { namespace {
auto MakeModel(std::string tree_method, std::string objective, std::shared_ptr<DMatrix> dmat) { auto MakeModel(std::string tree_method, std::string device, std::string objective,
std::shared_ptr<DMatrix> dmat) {
std::unique_ptr<Learner> learner{Learner::Create({dmat})}; std::unique_ptr<Learner> learner{Learner::Create({dmat})};
learner->SetParam("tree_method", tree_method); learner->SetParam("tree_method", tree_method);
learner->SetParam("device", device);
learner->SetParam("objective", objective); learner->SetParam("objective", objective);
if (objective.find("quantile") != std::string::npos) { if (objective.find("quantile") != std::string::npos) {
learner->SetParam("quantile_alpha", "0.5"); learner->SetParam("quantile_alpha", "0.5");
@ -35,7 +37,7 @@ auto MakeModel(std::string tree_method, std::string objective, std::shared_ptr<D
} }
void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model, void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model,
std::string tree_method, std::string objective) { std::string tree_method, std::string device, std::string objective) {
auto const world_size = collective::GetWorldSize(); auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank(); auto const rank = collective::GetRank();
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)}; std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
@ -61,14 +63,14 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e
} }
std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)}; std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
auto model = MakeModel(tree_method, objective, sliced); auto model = MakeModel(tree_method, device, objective, sliced);
auto base_score = GetBaseScore(model); auto base_score = GetBaseScore(model);
ASSERT_EQ(base_score, expected_base_score); ASSERT_EQ(base_score, expected_base_score) << " rank " << rank;
ASSERT_EQ(model, expected_model); ASSERT_EQ(model, expected_model) << " rank " << rank;
} }
} // namespace } // namespace
class FederatedLearnerTest : public ::testing::TestWithParam<std::string> { class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string> {
std::unique_ptr<ServerForTest> server_; std::unique_ptr<ServerForTest> server_;
static int constexpr kWorldSize{3}; static int constexpr kWorldSize{3};
@ -76,7 +78,7 @@ class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); } void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
void TearDown() override { server_.reset(nullptr); } void TearDown() override { server_.reset(nullptr); }
void Run(std::string tree_method, std::string objective) { void Run(std::string tree_method, std::string device, std::string objective) {
static auto constexpr kRows{16}; static auto constexpr kRows{16};
static auto constexpr kCols{16}; static auto constexpr kCols{16};
@ -99,27 +101,35 @@ class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
} }
} }
auto model = MakeModel(tree_method, objective, dmat); auto model = MakeModel(tree_method, device, objective, dmat);
auto score = GetBaseScore(model); auto score = GetBaseScore(model);
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols, RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols,
score, model, tree_method, objective); score, model, tree_method, device, objective);
} }
}; };
TEST_P(FederatedLearnerTest, Approx) { TEST_P(VerticalFederatedLearnerTest, Approx) {
std::string objective = GetParam(); std::string objective = GetParam();
this->Run("approx", objective); this->Run("approx", "cpu", objective);
} }
TEST_P(FederatedLearnerTest, Hist) { TEST_P(VerticalFederatedLearnerTest, Hist) {
std::string objective = GetParam(); std::string objective = GetParam();
this->Run("hist", objective); this->Run("hist", "cpu", objective);
} }
INSTANTIATE_TEST_SUITE_P(FederatedLearnerObjective, FederatedLearnerTest, #if defined(XGBOOST_USE_CUDA)
TEST_P(VerticalFederatedLearnerTest, GPUHist) {
std::string objective = GetParam();
this->Run("hist", "cuda:0", objective);
}
#endif // defined(XGBOOST_USE_CUDA)
INSTANTIATE_TEST_SUITE_P(
FederatedLearnerObjective, VerticalFederatedLearnerTest,
::testing::ValuesIn(MakeObjNamesForTest()), ::testing::ValuesIn(MakeObjNamesForTest()),
[](const ::testing::TestParamInfo<FederatedLearnerTest::ParamType> &info) { [](const ::testing::TestParamInfo<VerticalFederatedLearnerTest::ParamType> &info) {
return ObjTestNameGenerator(info); return ObjTestNameGenerator(info);
}); });
} // namespace xgboost } // namespace xgboost