Support vertical federated learning with gpu_hist (#9539)
parent 9bab06cbca
commit c928dd4ff5
@@ -26,7 +26,6 @@ namespace collective {
  * applied there, with the results broadcast to other workers.
  *
  * @tparam Function The function used to calculate the results.
- * @tparam Args Arguments to the function.
  * @param info MetaInfo about the DMatrix.
  * @param buffer The buffer storing the results.
  * @param size The size of the buffer.
@@ -57,6 +56,52 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
   }
 }
 
+/**
+ * @brief Apply the given function where the labels are.
+ *
+ * Normally all the workers have access to the labels, so the function is just applied locally. In
+ * vertical federated learning, we assume labels are only available on worker 0, so the function is
+ * applied there, with the results broadcast to other workers.
+ *
+ * @tparam T Type of the HostDeviceVector storing the results.
+ * @tparam Function The function used to calculate the results.
+ * @param info MetaInfo about the DMatrix.
+ * @param result The HostDeviceVector storing the results.
+ * @param function The function used to calculate the results.
+ */
+template <typename T, typename Function>
+void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
+  if (info.IsVerticalFederated()) {
+    // We assume labels are only available on worker 0, so the calculation is done there and result
+    // broadcast to other workers.
+    std::string message;
+    if (collective::GetRank() == 0) {
+      try {
+        std::forward<Function>(function)();
+      } catch (dmlc::Error& e) {
+        message = e.what();
+      }
+    }
+
+    collective::Broadcast(&message, 0);
+    if (!message.empty()) {
+      LOG(FATAL) << &message[0];
+      return;
+    }
+
+    std::size_t size{};
+    if (collective::GetRank() == 0) {
+      size = result->Size();
+    }
+    collective::Broadcast(&size, sizeof(std::size_t), 0);
+
+    result->Resize(size);
+    collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
+  } else {
+    std::forward<Function>(function)();
+  }
+}
+
 /**
  * @brief Find the global max of the given value across all workers.
  *
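Note the order of operations in the new overload: worker 0's exception is caught and its message broadcast before the results, so if the function throws, every worker fails together instead of the other ranks blocking forever inside the result broadcast. As a usage illustration, here is a minimal sketch of a caller; the half-the-label computation and variable names are invented for the example, and the usual xgboost headers are assumed to be in scope.

    // Hypothetical call site. Under vertical federated learning only rank 0
    // runs the lambda (it alone has info.labels); the Broadcast calls inside
    // ApplyWithLabels then resize and fill `results` on every other worker.
    HostDeviceVector<float> results;
    collective::ApplyWithLabels(info, &results, [&] {
      auto h_labels = info.labels.HostView();  // linalg::TensorView<float, 2>
      results.Resize(h_labels.Shape(0));
      auto& h_results = results.HostVector();
      for (std::size_t i = 0; i < h_labels.Shape(0); ++i) {
        h_results[i] = 0.5f * h_labels(i, 0);  // placeholder label-dependent work
      }
    });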
@@ -847,8 +847,7 @@ class LearnerConfiguration : public Learner {
 
   void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
     base_score->Reshape(1);
-    collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
-                                sizeof(bst_float) * base_score->Size(),
+    collective::ApplyWithLabels(info, base_score->Data(),
                                 [&] { UsePtr(obj_)->InitEstimation(info, base_score); });
   }
 };
@@ -1467,8 +1466,7 @@ class LearnerImpl : public LearnerIO {
   void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
                    std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
     out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
-    collective::ApplyWithLabels(info, out_gpair->Data()->HostPointer(),
-                                out_gpair->Size() * sizeof(GradientPair),
+    collective::ApplyWithLabels(info, out_gpair->Data(),
                                 [&] { obj_->GetGradient(preds, info, iter, out_gpair); });
   }
 
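Both learner call sites above make the same substitution: the old overload needed a pre-sized raw buffer plus an explicit byte count, while the new one takes the HostDeviceVector itself, reads the result size from worker 0, and resizes the other workers' vectors before broadcasting. A rough sketch of the two shapes (not code from the commit; `vec` is a hypothetical HostDeviceVector<float>* and `fn` stands in for either lambda):

    // Old shape: caller pre-sizes the buffer and spells out the byte count.
    collective::ApplyWithLabels(info, vec->HostPointer(), vec->Size() * sizeof(float), fn);
    // New shape: size and element type come from the HostDeviceVector, so
    // non-zero ranks need not know the result size in advance.
    collective::ApplyWithLabels(info, vec, fn);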
@@ -6,6 +6,7 @@
 #include <cstdint>  // std::int32_t
 #include <cub/cub.cuh>  // NOLINT
 
+#include "../collective/aggregator.h"
 #include "../common/cuda_context.cuh"  // CUDAContext
 #include "../common/device_helpers.cuh"
 #include "../common/stats.cuh"
@@ -154,15 +155,15 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
     UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
   }
 
-  HostDeviceVector<float> quantiles;
   predt.SetDevice(ctx->Device());
 
   auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_,
                                         predt.Size() / info.num_row_);
   CHECK_LT(group_idx, d_predt.Shape(1));
   auto t_predt = d_predt.Slice(linalg::All(), group_idx);
-  auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
 
-  auto d_row_index = dh::ToSpan(ridx);
-  auto seg_beg = nptr.DevicePointer();
-  auto seg_end = seg_beg + nptr.Size();
+  HostDeviceVector<float> quantiles;
+  collective::ApplyWithLabels(info, &quantiles, [&] {
+    auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
+    auto d_row_index = dh::ToSpan(ridx);
+    auto seg_beg = nptr.DevicePointer();
+    auto seg_end = seg_beg + nptr.Size();
@@ -181,11 +182,12 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
     info.weights_.SetDevice(ctx->Device());
     auto d_weights = info.weights_.ConstDeviceSpan();
     CHECK_EQ(d_weights.size(), d_row_index.size());
-    auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
+    auto w_it =
+        thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
     common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
                                       w_it + d_weights.size(), &quantiles);
   }
+  });
 
   UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
 }
 }  // namespace detail
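The only change to w_it is the line break forced by the deeper nesting inside the lambda; the idiom itself is a thrust permutation iterator that reads each weight through the sorted row index, so the weighted quantile visits weights in the same order as the gathered values. A self-contained illustration with synthetic data (not from the commit):

    #include <thrust/device_vector.h>
    #include <thrust/iterator/permutation_iterator.h>
    #include <vector>

    int main() {
      std::vector<float> h_weights{0.1f, 0.2f, 0.3f, 0.4f};
      std::vector<int> h_index{3, 1, 0, 2};
      thrust::device_vector<float> weights(h_weights.begin(), h_weights.end());
      thrust::device_vector<int> row_index(h_index.begin(), h_index.end());
      // Position i dereferences weights[row_index[i]]; no gathered copy is made.
      auto w_it = thrust::make_permutation_iterator(weights.begin(), row_index.begin());
      float first = *w_it;  // weights[3] == 0.4f
      return first == 0.4f ? 0 : 1;
    }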
@@ -55,11 +55,11 @@ void FitStump(Context const* ctx, MetaInfo const& info,
 }  // namespace cpu_impl
 
 namespace cuda_impl {
-void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
-              linalg::VectorView<float> out);
+void FitStump(Context const* ctx, MetaInfo const& info,
+              linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out);
 
 #if !defined(XGBOOST_USE_CUDA)
-inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
+inline void FitStump(Context const*, MetaInfo const&, linalg::TensorView<GradientPair const, 2>,
                      linalg::VectorView<float>) {
   common::AssertGPUSupport();
 }
@@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientP
   gpair.SetDevice(ctx->Device());
   auto gpair_t = gpair.View(ctx->Device());
   ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
-               : cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->Device()));
+               : cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device()));
 }
 }  // namespace tree
 }  // namespace xgboost
@@ -11,6 +11,7 @@
 
 #include <cstddef>  // std::size_t
 
+#include "../collective/aggregator.cuh"
 #include "../collective/communicator-inl.cuh"
 #include "../common/device_helpers.cuh"  // dh::MakeTransformIterator
 #include "fit_stump.h"
@@ -23,8 +24,8 @@
 namespace xgboost {
 namespace tree {
 namespace cuda_impl {
-void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
-              linalg::VectorView<float> out) {
+void FitStump(Context const* ctx, MetaInfo const& info,
+              linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out) {
   auto n_targets = out.Size();
   CHECK_EQ(n_targets, gpair.Shape(1));
   linalg::Vector<GradientPairPrecise> sum = linalg::Constant(ctx, GradientPairPrecise{}, n_targets);
@@ -49,8 +50,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
   thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
                         thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
 
-  collective::AllReduce<collective::Operation::kSum>(
-      ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
+  collective::GlobalSum(info, ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()),
+                        d_sum.Size() * 2);
 
   thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
                      [=] XGBOOST_DEVICE(std::size_t i) mutable {
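collective::GlobalSum is why MetaInfo is now threaded through FitStump: under a vertical (column-wise) split every worker already holds the gradients for all rows, so an unconditional AllReduce would multiply the sums by the number of workers. A sketch of the semantics this relies on; the helper name, predicate, and signature are illustrative rather than the exact aggregator API:

    // Sketch: only reduce across workers when the data is split by rows.
    template <typename T>
    void GlobalSumSketch(MetaInfo const& info, T* values, std::size_t size) {
      if (info.IsRowSplit()) {       // row split: each worker holds a partial sum
        AllreduceSum(values, size);  // hypothetical sum-allreduce over all workers
      }
      // Column (vertical) split: every worker already saw every row, so the
      // local sum is already the global sum and no communication is needed.
    }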
@@ -15,9 +15,11 @@
 
 namespace xgboost {
 namespace {
-auto MakeModel(std::string tree_method, std::string objective, std::shared_ptr<DMatrix> dmat) {
+auto MakeModel(std::string tree_method, std::string device, std::string objective,
+               std::shared_ptr<DMatrix> dmat) {
   std::unique_ptr<Learner> learner{Learner::Create({dmat})};
   learner->SetParam("tree_method", tree_method);
+  learner->SetParam("device", device);
   learner->SetParam("objective", objective);
   if (objective.find("quantile") != std::string::npos) {
     learner->SetParam("quantile_alpha", "0.5");
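The new device argument follows the split XGBoost was making between algorithm and hardware: tree_method only selects the algorithm (hist, approx) while device selects where it runs, so the old gpu_hist is expressed as hist on cuda:0, which is exactly what the GPUHist test further below passes. A minimal sketch of configuring a learner this way (DMatrix setup omitted):

    // Equivalent of the old "gpu_hist": the hist algorithm on the first CUDA device.
    std::unique_ptr<Learner> learner{Learner::Create({dmat})};
    learner->SetParam("tree_method", "hist");
    learner->SetParam("device", "cuda:0");
    learner->SetParam("objective", "reg:squarederror");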
@@ -35,7 +37,7 @@ auto MakeModel(std::string tree_method, std::string objective, std::shared_ptr<D
 }
 
 void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model,
-                     std::string tree_method, std::string objective) {
+                     std::string tree_method, std::string device, std::string objective) {
   auto const world_size = collective::GetWorldSize();
   auto const rank = collective::GetRank();
   std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
@@ -61,14 +63,14 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e
   }
   std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
 
-  auto model = MakeModel(tree_method, objective, sliced);
+  auto model = MakeModel(tree_method, device, objective, sliced);
   auto base_score = GetBaseScore(model);
-  ASSERT_EQ(base_score, expected_base_score);
-  ASSERT_EQ(model, expected_model);
+  ASSERT_EQ(base_score, expected_base_score) << " rank " << rank;
+  ASSERT_EQ(model, expected_model) << " rank " << rank;
 }
 }  // namespace
 
-class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
+class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string> {
   std::unique_ptr<ServerForTest> server_;
   static int constexpr kWorldSize{3};
 
@@ -76,7 +78,7 @@ class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
   void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
   void TearDown() override { server_.reset(nullptr); }
 
-  void Run(std::string tree_method, std::string objective) {
+  void Run(std::string tree_method, std::string device, std::string objective) {
     static auto constexpr kRows{16};
     static auto constexpr kCols{16};
 
@@ -99,27 +101,35 @@ class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
       }
     }
 
-    auto model = MakeModel(tree_method, objective, dmat);
+    auto model = MakeModel(tree_method, device, objective, dmat);
     auto score = GetBaseScore(model);
 
     RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols,
-                                 score, model, tree_method, objective);
+                                 score, model, tree_method, device, objective);
   }
 };
 
-TEST_P(FederatedLearnerTest, Approx) {
+TEST_P(VerticalFederatedLearnerTest, Approx) {
   std::string objective = GetParam();
-  this->Run("approx", objective);
+  this->Run("approx", "cpu", objective);
 }
 
-TEST_P(FederatedLearnerTest, Hist) {
+TEST_P(VerticalFederatedLearnerTest, Hist) {
   std::string objective = GetParam();
-  this->Run("hist", objective);
+  this->Run("hist", "cpu", objective);
 }
 
-INSTANTIATE_TEST_SUITE_P(FederatedLearnerObjective, FederatedLearnerTest,
+#if defined(XGBOOST_USE_CUDA)
+TEST_P(VerticalFederatedLearnerTest, GPUHist) {
+  std::string objective = GetParam();
+  this->Run("hist", "cuda:0", objective);
+}
+#endif  // defined(XGBOOST_USE_CUDA)
+
+INSTANTIATE_TEST_SUITE_P(
+    FederatedLearnerObjective, VerticalFederatedLearnerTest,
     ::testing::ValuesIn(MakeObjNamesForTest()),
-    [](const ::testing::TestParamInfo<FederatedLearnerTest::ParamType> &info) {
+    [](const ::testing::TestParamInfo<VerticalFederatedLearnerTest::ParamType> &info) {
       return ObjTestNameGenerator(info);
     });
 }  // namespace xgboost