More tests for column split and vertical federated learning (#8985)
Added some more tests for the learner and fit_stump, for both column-wise distributed learning and vertical federated learning. Also moved the `IsRowSplit` and `IsColumnSplit` methods from the `DMatrix` to the `MetaInfo` since in some places we only have access to the `MetaInfo`. Added a new convenience method `IsVerticalFederatedLearning`. Some refactoring of the testing fixtures.
This commit is contained in:
@@ -860,9 +860,9 @@ class LearnerConfiguration : public Learner {
|
||||
|
||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
|
||||
// Special handling for vertical federated learning.
|
||||
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the estimation is calculated there
|
||||
// and added to other workers.
|
||||
// and broadcast to other workers.
|
||||
if (collective::GetRank() == 0) {
|
||||
UsePtr(obj_)->InitEstimation(info, base_score);
|
||||
collective::Broadcast(base_score->Data()->HostPointer(),
|
||||
@@ -1487,7 +1487,7 @@ class LearnerImpl : public LearnerIO {
|
||||
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
|
||||
HostDeviceVector<GradientPair>* out_gpair) {
|
||||
// Special handling for vertical federated learning.
|
||||
if (collective::IsFederated() && info.data_split_mode == DataSplitMode::kCol) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the gradients are calculated there
|
||||
// and broadcast to other workers.
|
||||
if (collective::GetRank() == 0) {
|
||||
|
||||
Reference in New Issue
Block a user