change workflow
This commit is contained in:
parent
b22644fc10
commit
3a834c4992
@ -860,10 +860,21 @@ class LearnerConfiguration : public Learner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
|
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
|
||||||
|
#ifndef XGBOOST_USE_HIP
|
||||||
base_score->Reshape(1);
|
base_score->Reshape(1);
|
||||||
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
|
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
|
||||||
sizeof(bst_float) * base_score->Size(),
|
sizeof(bst_float) * base_score->Size(),
|
||||||
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
|
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
|
||||||
|
#else
|
||||||
|
if (info.IsVerticalFederated()) {
|
||||||
|
base_score->Reshape(1);
|
||||||
|
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
|
||||||
|
sizeof(bst_float) * base_score->Size(),
|
||||||
|
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
|
||||||
|
} else {
|
||||||
|
UsePtr(obj_)->InitEstimation(info, base_score);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1475,10 +1486,21 @@ class LearnerImpl : public LearnerIO {
|
|||||||
private:
|
private:
|
||||||
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
|
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info, int iteration,
|
||||||
HostDeviceVector<GradientPair>* out_gpair) {
|
HostDeviceVector<GradientPair>* out_gpair) {
|
||||||
|
#ifndef XGBOOST_USE_HIP
|
||||||
out_gpair->Resize(preds.Size());
|
out_gpair->Resize(preds.Size());
|
||||||
collective::ApplyWithLabels(info, out_gpair->HostPointer(),
|
collective::ApplyWithLabels(info, out_gpair->HostPointer(),
|
||||||
out_gpair->Size() * sizeof(GradientPair),
|
out_gpair->Size() * sizeof(GradientPair),
|
||||||
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
|
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
|
||||||
|
#else
|
||||||
|
if (info.IsVerticalFederated()) {
|
||||||
|
out_gpair->Resize(preds.Size());
|
||||||
|
collective::ApplyWithLabels(info, out_gpair->HostPointer(),
|
||||||
|
out_gpair->Size() * sizeof(GradientPair),
|
||||||
|
[&] { obj_->GetGradient(preds, info, iteration, out_gpair); });
|
||||||
|
} else {
|
||||||
|
obj_->GetGradient(preds, info, iteration, out_gpair);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief random number transformation seed. */
|
/*! \brief random number transformation seed. */
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user