restore learner
This commit is contained in:
parent
6762230d9a
commit
6bbca9a8b7
@ -846,20 +846,9 @@ class LearnerConfiguration : public Learner {
|
||||
}
|
||||
|
||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
|
||||
#ifndef XGBOOST_USE_HIP
|
||||
base_score->Reshape(1);
|
||||
collective::ApplyWithLabels(info, base_score->Data(),
|
||||
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
|
||||
#else
|
||||
if (info.IsVerticalFederated()) {
|
||||
base_score->Reshape(1);
|
||||
collective::ApplyWithLabels(info, base_score->Data()->HostPointer(),
|
||||
sizeof(bst_float) * base_score->Size(),
|
||||
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
|
||||
} else {
|
||||
UsePtr(obj_)->InitEstimation(info, base_score);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@ -1478,20 +1467,9 @@ class LearnerImpl : public LearnerIO {
|
||||
private:
|
||||
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
|
||||
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
|
||||
#ifndef XGBOOST_USE_HIP
|
||||
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
|
||||
collective::ApplyWithLabels(info, out_gpair->Data(),
|
||||
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
|
||||
#else
|
||||
if (info.IsVerticalFederated()) {
|
||||
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
|
||||
collective::ApplyWithLabels(info, out_gpair->Data(),
|
||||
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
|
||||
}
|
||||
else {
|
||||
obj_->GetGradient(preds, info, iter, out_gpair);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/*! \brief random number transformation seed. */
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user