Support vertical federated learning with gpu_hist (#9539)

This commit is contained in:
Rong Ou
2023-09-02 20:37:11 -07:00
committed by GitHub
parent 9bab06cbca
commit c928dd4ff5
6 changed files with 113 additions and 57 deletions

View File

@@ -6,6 +6,7 @@
#include <cstdint> // std::int32_t
#include <cub/cub.cuh> // NOLINT
#include "../collective/aggregator.h"
#include "../common/cuda_context.cuh" // CUDAContext
#include "../common/device_helpers.cuh"
#include "../common/stats.cuh"
@@ -154,38 +155,39 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
}
HostDeviceVector<float> quantiles;
predt.SetDevice(ctx->Device());
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_,
predt.Size() / info.num_row_);
CHECK_LT(group_idx, d_predt.Shape(1));
auto t_predt = d_predt.Slice(linalg::All(), group_idx);
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
auto d_row_index = dh::ToSpan(ridx);
auto seg_beg = nptr.DevicePointer();
auto seg_end = seg_beg + nptr.Size();
auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(size_t i) {
float p = t_predt(d_row_index[i]);
auto y = d_labels(d_row_index[i]);
return y - p;
});
CHECK_EQ(d_labels.Shape(0), position.size());
auto val_end = val_beg + d_labels.Shape(0);
CHECK_EQ(nidx.Size() + 1, nptr.Size());
if (info.weights_.Empty()) {
common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
} else {
info.weights_.SetDevice(ctx->Device());
auto d_weights = info.weights_.ConstDeviceSpan();
CHECK_EQ(d_weights.size(), d_row_index.size());
auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
w_it + d_weights.size(), &quantiles);
}
HostDeviceVector<float> quantiles;
collective::ApplyWithLabels(info, &quantiles, [&] {
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
auto d_row_index = dh::ToSpan(ridx);
auto seg_beg = nptr.DevicePointer();
auto seg_end = seg_beg + nptr.Size();
auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(size_t i) {
float p = t_predt(d_row_index[i]);
auto y = d_labels(d_row_index[i]);
return y - p;
});
CHECK_EQ(d_labels.Shape(0), position.size());
auto val_end = val_beg + d_labels.Shape(0);
CHECK_EQ(nidx.Size() + 1, nptr.Size());
if (info.weights_.Empty()) {
common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
} else {
info.weights_.SetDevice(ctx->Device());
auto d_weights = info.weights_.ConstDeviceSpan();
CHECK_EQ(d_weights.size(), d_row_index.size());
auto w_it =
thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
w_it + d_weights.size(), &quantiles);
}
});
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
}
} // namespace detail