diff --git a/src/learner.cc b/src/learner.cc
index 78297404b..7df450811 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -860,10 +860,21 @@ class LearnerConfiguration : public Learner {
   }
 
   void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
+#ifndef XGBOOST_USE_HIP
     base_score->Reshape(1);
     collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
                                 sizeof(bst_float) * base_score->Size(),
                                 [&] { UsePtr(obj_)->InitEstimation(info, base_score); });
+#else
+    if (info.IsVerticalFederated()) {
+      base_score->Reshape(1);
+      collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
+                                  sizeof(bst_float) * base_score->Size(),
+                                  [&] { UsePtr(obj_)->InitEstimation(info, base_score); });
+    } else {
+      UsePtr(obj_)->InitEstimation(info, base_score);
+    }
+#endif
   }
 };
 
@@ -1475,10 +1486,21 @@ class LearnerImpl : public LearnerIO {
  private:
  void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
                   HostDeviceVector<GradientPair>* out_gpair) {
+#ifndef XGBOOST_USE_HIP
     out_gpair->Resize(preds.Size());
     collective::ApplyWithLabels(info, out_gpair->HostPointer(),
                                 out_gpair->Size() * sizeof(GradientPair),
                                 [&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
+#else
+    if (info.IsVerticalFederated()) {
+      out_gpair->Resize(preds.Size());
+      collective::ApplyWithLabels(info, out_gpair->HostPointer(),
+                                  out_gpair->Size() * sizeof(GradientPair),
+                                  [&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
+    } else {
+      obj_->GetGradient(preds, info, iteration, out_gpair);
+    }
+#endif
   }
 
   /*! \brief random number transformation seed. */