change workflow

This commit is contained in:
amdsc21 2023-05-20 07:04:06 +02:00
parent b22644fc10
commit 3a834c4992

View File

@ -860,10 +860,21 @@ class LearnerConfiguration : public Learner {
}
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
#ifndef XGBOOST_USE_HIP
base_score->Reshape(1);
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(),
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
#else
if (info.IsVerticalFederated()) {
base_score->Reshape(1);
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
sizeof(bst_float) * base_score->Size(),
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
} else {
UsePtr(obj_)->InitEstimation(info, base_score);
}
#endif
}
};
@ -1475,10 +1486,21 @@ class LearnerImpl : public LearnerIO {
private:
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
HostDeviceVector<GradientPair>* out_gpair) {
#ifndef XGBOOST_USE_HIP
out_gpair->Resize(preds.Size());
collective::ApplyWithLabels(info, out_gpair->HostPointer(),
out_gpair->Size() * sizeof(GradientPair),
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
#else
if (info.IsVerticalFederated()) {
out_gpair->Resize(preds.Size());
collective::ApplyWithLabels(info, out_gpair->HostPointer(),
out_gpair->Size() * sizeof(GradientPair),
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
} else {
obj_->GetGradient(preds, info, iteration, out_gpair);
}
#endif
}
/*! \brief random number transformation seed. */