Support vertical federated learning with gpu_hist (#9539)
This commit is contained in:
parent
9bab06cbca
commit
c928dd4ff5
@ -26,7 +26,6 @@ namespace collective {
|
||||
* applied there, with the results broadcast to other workers.
|
||||
*
|
||||
* @tparam Function The function used to calculate the results.
|
||||
* @tparam Args Arguments to the function.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param buffer The buffer storing the results.
|
||||
* @param size The size of the buffer.
|
||||
@ -57,6 +56,52 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Apply the given function where the labels are.
|
||||
*
|
||||
* Normally all the workers have access to the labels, so the function is just applied locally. In
|
||||
* vertical federated learning, we assume labels are only available on worker 0, so the function is
|
||||
* applied there, with the results broadcast to other workers.
|
||||
*
|
||||
* @tparam T Type of the HostDeviceVector storing the results.
|
||||
* @tparam Function The function used to calculate the results.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param result The HostDeviceVector storing the results.
|
||||
* @param function The function used to calculate the results.
|
||||
*/
|
||||
template <typename T, typename Function>
|
||||
void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::string message;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<Function>(function)();
|
||||
} catch (dmlc::Error& e) {
|
||||
message = e.what();
|
||||
}
|
||||
}
|
||||
|
||||
collective::Broadcast(&message, 0);
|
||||
if (!message.empty()) {
|
||||
LOG(FATAL) << &message[0];
|
||||
return;
|
||||
}
|
||||
|
||||
std::size_t size{};
|
||||
if (collective::GetRank() == 0) {
|
||||
size = result->Size();
|
||||
}
|
||||
collective::Broadcast(&size, sizeof(std::size_t), 0);
|
||||
|
||||
result->Resize(size);
|
||||
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
|
||||
} else {
|
||||
std::forward<Function>(function)();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find the global max of the given value across all workers.
|
||||
*
|
||||
|
||||
@ -847,8 +847,7 @@ class LearnerConfiguration : public Learner {
|
||||
|
||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
|
||||
base_score->Reshape(1);
|
||||
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
|
||||
sizeof(bst_float) * base_score->Size(),
|
||||
collective::ApplyWithLabels(info, base_score->Data(),
|
||||
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
|
||||
}
|
||||
};
|
||||
@ -1467,8 +1466,7 @@ class LearnerImpl : public LearnerIO {
|
||||
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
|
||||
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
|
||||
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
|
||||
collective::ApplyWithLabels(info, out_gpair->Data()->HostPointer(),
|
||||
out_gpair->Size() * sizeof(GradientPair),
|
||||
collective::ApplyWithLabels(info, out_gpair->Data(),
|
||||
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
|
||||
}
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cub/cub.cuh> // NOLINT
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/stats.cuh"
|
||||
@ -154,38 +155,39 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
||||
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
|
||||
}
|
||||
|
||||
HostDeviceVector<float> quantiles;
|
||||
predt.SetDevice(ctx->Device());
|
||||
|
||||
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_,
|
||||
predt.Size() / info.num_row_);
|
||||
CHECK_LT(group_idx, d_predt.Shape(1));
|
||||
auto t_predt = d_predt.Slice(linalg::All(), group_idx);
|
||||
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
|
||||
|
||||
auto d_row_index = dh::ToSpan(ridx);
|
||||
auto seg_beg = nptr.DevicePointer();
|
||||
auto seg_end = seg_beg + nptr.Size();
|
||||
auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(size_t i) {
|
||||
float p = t_predt(d_row_index[i]);
|
||||
auto y = d_labels(d_row_index[i]);
|
||||
return y - p;
|
||||
});
|
||||
CHECK_EQ(d_labels.Shape(0), position.size());
|
||||
auto val_end = val_beg + d_labels.Shape(0);
|
||||
CHECK_EQ(nidx.Size() + 1, nptr.Size());
|
||||
if (info.weights_.Empty()) {
|
||||
common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
|
||||
} else {
|
||||
info.weights_.SetDevice(ctx->Device());
|
||||
auto d_weights = info.weights_.ConstDeviceSpan();
|
||||
CHECK_EQ(d_weights.size(), d_row_index.size());
|
||||
auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
|
||||
common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
|
||||
w_it + d_weights.size(), &quantiles);
|
||||
}
|
||||
|
||||
HostDeviceVector<float> quantiles;
|
||||
collective::ApplyWithLabels(info, &quantiles, [&] {
|
||||
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
|
||||
auto d_row_index = dh::ToSpan(ridx);
|
||||
auto seg_beg = nptr.DevicePointer();
|
||||
auto seg_end = seg_beg + nptr.Size();
|
||||
auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(size_t i) {
|
||||
float p = t_predt(d_row_index[i]);
|
||||
auto y = d_labels(d_row_index[i]);
|
||||
return y - p;
|
||||
});
|
||||
CHECK_EQ(d_labels.Shape(0), position.size());
|
||||
auto val_end = val_beg + d_labels.Shape(0);
|
||||
CHECK_EQ(nidx.Size() + 1, nptr.Size());
|
||||
if (info.weights_.Empty()) {
|
||||
common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
|
||||
} else {
|
||||
info.weights_.SetDevice(ctx->Device());
|
||||
auto d_weights = info.weights_.ConstDeviceSpan();
|
||||
CHECK_EQ(d_weights.size(), d_row_index.size());
|
||||
auto w_it =
|
||||
thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
|
||||
common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
|
||||
w_it + d_weights.size(), &quantiles);
|
||||
}
|
||||
});
|
||||
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
@ -55,11 +55,11 @@ void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
} // namespace cpu_impl
|
||||
|
||||
namespace cuda_impl {
|
||||
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
|
||||
linalg::VectorView<float> out);
|
||||
void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out);
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
|
||||
inline void FitStump(Context const*, MetaInfo const&, linalg::TensorView<GradientPair const, 2>,
|
||||
linalg::VectorView<float>) {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::Matrix<GradientP
|
||||
gpair.SetDevice(ctx->Device());
|
||||
auto gpair_t = gpair.View(ctx->Device());
|
||||
ctx->IsCPU() ? cpu_impl::FitStump(ctx, info, gpair_t, out->HostView())
|
||||
: cuda_impl::FitStump(ctx, gpair_t, out->View(ctx->Device()));
|
||||
: cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device()));
|
||||
}
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
|
||||
#include "../collective/aggregator.cuh"
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
|
||||
#include "fit_stump.h"
|
||||
@ -23,8 +24,8 @@
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
namespace cuda_impl {
|
||||
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
|
||||
linalg::VectorView<float> out) {
|
||||
void FitStump(Context const* ctx, MetaInfo const& info,
|
||||
linalg::TensorView<GradientPair const, 2> gpair, linalg::VectorView<float> out) {
|
||||
auto n_targets = out.Size();
|
||||
CHECK_EQ(n_targets, gpair.Shape(1));
|
||||
linalg::Vector<GradientPairPrecise> sum = linalg::Constant(ctx, GradientPairPrecise{}, n_targets);
|
||||
@ -49,8 +50,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
|
||||
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
|
||||
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));
|
||||
|
||||
collective::AllReduce<collective::Operation::kSum>(
|
||||
ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
|
||||
collective::GlobalSum(info, ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()),
|
||||
d_sum.Size() * 2);
|
||||
|
||||
thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
|
||||
[=] XGBOOST_DEVICE(std::size_t i) mutable {
|
||||
|
||||
@ -15,9 +15,11 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace {
|
||||
auto MakeModel(std::string tree_method, std::string objective, std::shared_ptr<DMatrix> dmat) {
|
||||
auto MakeModel(std::string tree_method, std::string device, std::string objective,
|
||||
std::shared_ptr<DMatrix> dmat) {
|
||||
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
|
||||
learner->SetParam("tree_method", tree_method);
|
||||
learner->SetParam("device", device);
|
||||
learner->SetParam("objective", objective);
|
||||
if (objective.find("quantile") != std::string::npos) {
|
||||
learner->SetParam("quantile_alpha", "0.5");
|
||||
@ -35,7 +37,7 @@ auto MakeModel(std::string tree_method, std::string objective, std::shared_ptr<D
|
||||
}
|
||||
|
||||
void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json expected_model,
|
||||
std::string tree_method, std::string objective) {
|
||||
std::string tree_method, std::string device, std::string objective) {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};
|
||||
@ -61,14 +63,14 @@ void VerifyObjective(size_t rows, size_t cols, float expected_base_score, Json e
|
||||
}
|
||||
std::shared_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
|
||||
|
||||
auto model = MakeModel(tree_method, objective, sliced);
|
||||
auto model = MakeModel(tree_method, device, objective, sliced);
|
||||
auto base_score = GetBaseScore(model);
|
||||
ASSERT_EQ(base_score, expected_base_score);
|
||||
ASSERT_EQ(model, expected_model);
|
||||
ASSERT_EQ(base_score, expected_base_score) << " rank " << rank;
|
||||
ASSERT_EQ(model, expected_model) << " rank " << rank;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
|
||||
class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string> {
|
||||
std::unique_ptr<ServerForTest> server_;
|
||||
static int constexpr kWorldSize{3};
|
||||
|
||||
@ -76,7 +78,7 @@ class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
|
||||
void SetUp() override { server_ = std::make_unique<ServerForTest>(kWorldSize); }
|
||||
void TearDown() override { server_.reset(nullptr); }
|
||||
|
||||
void Run(std::string tree_method, std::string objective) {
|
||||
void Run(std::string tree_method, std::string device, std::string objective) {
|
||||
static auto constexpr kRows{16};
|
||||
static auto constexpr kCols{16};
|
||||
|
||||
@ -99,27 +101,35 @@ class FederatedLearnerTest : public ::testing::TestWithParam<std::string> {
|
||||
}
|
||||
}
|
||||
|
||||
auto model = MakeModel(tree_method, objective, dmat);
|
||||
auto model = MakeModel(tree_method, device, objective, dmat);
|
||||
auto score = GetBaseScore(model);
|
||||
|
||||
RunWithFederatedCommunicator(kWorldSize, server_->Address(), &VerifyObjective, kRows, kCols,
|
||||
score, model, tree_method, objective);
|
||||
score, model, tree_method, device, objective);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(FederatedLearnerTest, Approx) {
|
||||
TEST_P(VerticalFederatedLearnerTest, Approx) {
|
||||
std::string objective = GetParam();
|
||||
this->Run("approx", objective);
|
||||
this->Run("approx", "cpu", objective);
|
||||
}
|
||||
|
||||
TEST_P(FederatedLearnerTest, Hist) {
|
||||
TEST_P(VerticalFederatedLearnerTest, Hist) {
|
||||
std::string objective = GetParam();
|
||||
this->Run("hist", objective);
|
||||
this->Run("hist", "cpu", objective);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(FederatedLearnerObjective, FederatedLearnerTest,
|
||||
::testing::ValuesIn(MakeObjNamesForTest()),
|
||||
[](const ::testing::TestParamInfo<FederatedLearnerTest::ParamType> &info) {
|
||||
return ObjTestNameGenerator(info);
|
||||
});
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST_P(VerticalFederatedLearnerTest, GPUHist) {
|
||||
std::string objective = GetParam();
|
||||
this->Run("hist", "cuda:0", objective);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
FederatedLearnerObjective, VerticalFederatedLearnerTest,
|
||||
::testing::ValuesIn(MakeObjNamesForTest()),
|
||||
[](const ::testing::TestParamInfo<VerticalFederatedLearnerTest::ParamType> &info) {
|
||||
return ObjTestNameGenerator(info);
|
||||
});
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user