Avoid omp reduction in coordinate descent and aft metrics. (#7316)

Aside from the omp issue, parameter configuration for the aft metric is simplified.
commit fb1a9e6bc5 (parent f56e2e9a66)
Author: Jiaming Yuan
Date:   2021-10-17 15:55:49 +08:00
Committed by: GitHub
4 changed files with 112 additions and 70 deletions
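The change follows one pattern throughout: replace an OpenMP `reduction` clause with per-thread accumulators that are summed sequentially after the parallel loop. A minimal self-contained sketch of that pattern (plain C++ and OpenMP; this is an illustration, not the XGBoost helpers used in the diffs below):

#include <omp.h>

#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<double> values(1000, 0.5);
  int n_threads = omp_get_max_threads();

  // One accumulator slot per thread instead of `reduction(+ : sum)`.
  std::vector<double> sum_tloc(n_threads, 0.0);

#pragma omp parallel for schedule(static) num_threads(n_threads)
  for (int i = 0; i < static_cast<int>(values.size()); ++i) {
    // Each thread writes only to its own slot, so no reduction
    // clause (and no synchronization) is needed inside the loop.
    sum_tloc[omp_get_thread_num()] += values[i];
  }

  // Combine the per-thread partial sums in a fixed, sequential order.
  double sum = std::accumulate(sum_tloc.cbegin(), sum_tloc.cend(), 0.0);
  std::cout << sum << std::endl;  // prints 500
  return 0;
}

Because the per-thread partials are combined in a fixed order, the floating-point result is reproducible for a given thread count and static schedule, whereas an OpenMP reduction may combine partials in an unspecified order.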


@@ -108,27 +108,32 @@ inline std::pair<double, double> GetGradient(int group_idx, int num_group, int f
  *
  * \return The gradient and diagonal Hessian entry for a given feature.
  */
-inline std::pair<double, double> GetGradientParallel(int group_idx, int num_group, int fidx,
-                                                     const std::vector<GradientPair> &gpair,
-                                                     DMatrix *p_fmat) {
-  double sum_grad = 0.0, sum_hess = 0.0;
+inline std::pair<double, double>
+GetGradientParallel(GenericParameter const *ctx, int group_idx, int num_group,
+                    int fidx, const std::vector<GradientPair> &gpair,
+                    DMatrix *p_fmat) {
+  std::vector<double> sum_grad_tloc(ctx->Threads(), 0.0);
+  std::vector<double> sum_hess_tloc(ctx->Threads(), 0.0);
   for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
     auto page = batch.GetView();
     auto col = page[fidx];
     const auto ndata = static_cast<bst_omp_uint>(col.size());
-    dmlc::OMPException exc;
-#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
-    for (bst_omp_uint j = 0; j < ndata; ++j) {
-      exc.Run([&]() {
-        const bst_float v = col[j].fvalue;
-        auto &p = gpair[col[j].index * num_group + group_idx];
-        if (p.GetHess() < 0.0f) return;
-        sum_grad += p.GetGrad() * v;
-        sum_hess += p.GetHess() * v * v;
-      });
-    }
-    exc.Rethrow();
+    common::ParallelFor(ndata, ctx->Threads(), [&](size_t j) {
+      const bst_float v = col[j].fvalue;
+      auto &p = gpair[col[j].index * num_group + group_idx];
+      if (p.GetHess() < 0.0f) {
+        return;
+      }
+      auto t_idx = omp_get_thread_num();
+      sum_grad_tloc[t_idx] += p.GetGrad() * v;
+      sum_hess_tloc[t_idx] += p.GetHess() * v * v;
+    });
   }
+  double sum_grad =
+      std::accumulate(sum_grad_tloc.cbegin(), sum_grad_tloc.cend(), 0.0);
+  double sum_hess =
+      std::accumulate(sum_hess_tloc.cbegin(), sum_hess_tloc.cend(), 0.0);
   return std::make_pair(sum_grad, sum_hess);
 }
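`common::ParallelFor` is an XGBoost internal helper. The sketch below is a rough, hypothetical approximation of what such a helper does (run `fn(i)` for `i` in `[0, n)` on a bounded number of threads, propagating the first exception thrown in the loop body); it is not the actual XGBoost implementation:

#include <omp.h>

#include <cstddef>
#include <exception>

// Hypothetical simplification of a ParallelFor helper: run fn(i) for
// i in [0, n) on up to n_threads threads, and rethrow the first
// exception captured inside the loop body once the loop finishes.
template <typename Fn>
void ParallelForSketch(std::size_t n, int n_threads, Fn &&fn) {
  std::exception_ptr first_error{nullptr};
#pragma omp parallel for schedule(static) num_threads(n_threads)
  for (std::ptrdiff_t i = 0; i < static_cast<std::ptrdiff_t>(n); ++i) {
    try {
      fn(static_cast<std::size_t>(i));
    } catch (...) {
#pragma omp critical
      {
        // Keep only the first captured exception.
        if (!first_error) {
          first_error = std::current_exception();
        }
      }
    }
  }
  if (first_error) {
    std::rethrow_exception(first_error);
  }
}

This is also why the diff can drop the explicit `dmlc::OMPException` bookkeeping: exception capture moves into the helper instead of being repeated at every call site.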


@@ -80,8 +80,8 @@ class CoordinateUpdater : public LinearUpdater {
                  DMatrix *p_fmat, gbm::GBLinearModel *model) {
     const int ngroup = model->learner_model_param->num_output_group;
     bst_float &w = (*model)[fidx][group_idx];
-    auto gradient =
-        GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat);
+    auto gradient = GetGradientParallel(learner_param_, group_idx, ngroup, fidx,
+                                        *in_gpair, p_fmat);
     auto dw = static_cast<float>(
         tparam_.learning_rate *
         CoordinateDelta(gradient.first, gradient.second, w, tparam_.reg_alpha_denorm,
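At the call site, the only change is threading the runtime context through, so the loop uses the learner-configured thread count rather than the global OpenMP default. A hedged usage sketch (the names `ctx`, `gpair`, and `p_fmat` are illustrative placeholders standing in for an existing gradient vector and DMatrix pointer, not code from the diff):

// Hypothetical call-site sketch: the context now supplies the thread
// count instead of the global OpenMP state.
GenericParameter ctx;
ctx.nthread = 4;  // the learner-configured thread count

auto gradient = GetGradientParallel(&ctx, /*group_idx=*/0, /*num_group=*/1,
                                    /*fidx=*/0, gpair, p_fmat);
double sum_grad = gradient.first;   // sum over column fidx of g_i * x_i
double sum_hess = gradient.second;  // sum over column fidx of h_i * x_i^2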