Support vertical federated learning with gpu_hist (#9539)
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cub/cub.cuh> // NOLINT
|
||||
|
||||
#include "../collective/aggregator.h"
|
||||
#include "../common/cuda_context.cuh" // CUDAContext
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/stats.cuh"
|
||||
@@ -154,38 +155,39 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
|
||||
UpdateLeafValues(&quantiles, nidx.ConstHostVector(), info, learning_rate, p_tree);
|
||||
}
|
||||
|
||||
HostDeviceVector<float> quantiles;
|
||||
predt.SetDevice(ctx->Device());
|
||||
|
||||
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_,
|
||||
predt.Size() / info.num_row_);
|
||||
CHECK_LT(group_idx, d_predt.Shape(1));
|
||||
auto t_predt = d_predt.Slice(linalg::All(), group_idx);
|
||||
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
|
||||
|
||||
auto d_row_index = dh::ToSpan(ridx);
|
||||
auto seg_beg = nptr.DevicePointer();
|
||||
auto seg_end = seg_beg + nptr.Size();
|
||||
auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(size_t i) {
|
||||
float p = t_predt(d_row_index[i]);
|
||||
auto y = d_labels(d_row_index[i]);
|
||||
return y - p;
|
||||
});
|
||||
CHECK_EQ(d_labels.Shape(0), position.size());
|
||||
auto val_end = val_beg + d_labels.Shape(0);
|
||||
CHECK_EQ(nidx.Size() + 1, nptr.Size());
|
||||
if (info.weights_.Empty()) {
|
||||
common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
|
||||
} else {
|
||||
info.weights_.SetDevice(ctx->Device());
|
||||
auto d_weights = info.weights_.ConstDeviceSpan();
|
||||
CHECK_EQ(d_weights.size(), d_row_index.size());
|
||||
auto w_it = thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
|
||||
common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
|
||||
w_it + d_weights.size(), &quantiles);
|
||||
}
|
||||
|
||||
HostDeviceVector<float> quantiles;
|
||||
collective::ApplyWithLabels(info, &quantiles, [&] {
|
||||
auto d_labels = info.labels.View(ctx->Device()).Slice(linalg::All(), IdxY(info, group_idx));
|
||||
auto d_row_index = dh::ToSpan(ridx);
|
||||
auto seg_beg = nptr.DevicePointer();
|
||||
auto seg_end = seg_beg + nptr.Size();
|
||||
auto val_beg = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(size_t i) {
|
||||
float p = t_predt(d_row_index[i]);
|
||||
auto y = d_labels(d_row_index[i]);
|
||||
return y - p;
|
||||
});
|
||||
CHECK_EQ(d_labels.Shape(0), position.size());
|
||||
auto val_end = val_beg + d_labels.Shape(0);
|
||||
CHECK_EQ(nidx.Size() + 1, nptr.Size());
|
||||
if (info.weights_.Empty()) {
|
||||
common::SegmentedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, &quantiles);
|
||||
} else {
|
||||
info.weights_.SetDevice(ctx->Device());
|
||||
auto d_weights = info.weights_.ConstDeviceSpan();
|
||||
CHECK_EQ(d_weights.size(), d_row_index.size());
|
||||
auto w_it =
|
||||
thrust::make_permutation_iterator(dh::tcbegin(d_weights), dh::tcbegin(d_row_index));
|
||||
common::SegmentedWeightedQuantile(ctx, alpha, seg_beg, seg_end, val_beg, val_end, w_it,
|
||||
w_it + d_weights.size(), &quantiles);
|
||||
}
|
||||
});
|
||||
UpdateLeafValues(&quantiles.HostVector(), nidx.ConstHostVector(), info, learning_rate, p_tree);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
Reference in New Issue
Block a user