Improve OpenMP exception handling (#6680)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include "./param.h"
|
||||
#include "../gbm/gblinear_model.h"
|
||||
#include "../common/random.h"
|
||||
#include "../common/threading_utils.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace linear {
|
||||
@@ -115,14 +116,18 @@ inline std::pair<double, double> GetGradientParallel(int group_idx, int num_grou
|
||||
auto page = batch.GetView();
|
||||
auto col = page[fidx];
|
||||
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
const bst_float v = col[j].fvalue;
|
||||
auto &p = gpair[col[j].index * num_group + group_idx];
|
||||
if (p.GetHess() < 0.0f) continue;
|
||||
sum_grad += p.GetGrad() * v;
|
||||
sum_hess += p.GetHess() * v * v;
|
||||
exc.Run([&]() {
|
||||
const bst_float v = col[j].fvalue;
|
||||
auto &p = gpair[col[j].index * num_group + group_idx];
|
||||
if (p.GetHess() < 0.0f) return;
|
||||
sum_grad += p.GetGrad() * v;
|
||||
sum_hess += p.GetHess() * v * v;
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
}
|
||||
return std::make_pair(sum_grad, sum_hess);
|
||||
}
|
||||
@@ -142,14 +147,18 @@ inline std::pair<double, double> GetBiasGradientParallel(int group_idx, int num_
|
||||
DMatrix *p_fmat) {
|
||||
double sum_grad = 0.0, sum_hess = 0.0;
|
||||
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
auto &p = gpair[i * num_group + group_idx];
|
||||
if (p.GetHess() >= 0.0f) {
|
||||
sum_grad += p.GetGrad();
|
||||
sum_hess += p.GetHess();
|
||||
}
|
||||
exc.Run([&]() {
|
||||
auto &p = gpair[i * num_group + group_idx];
|
||||
if (p.GetHess() >= 0.0f) {
|
||||
sum_grad += p.GetGrad();
|
||||
sum_hess += p.GetHess();
|
||||
}
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
return std::make_pair(sum_grad, sum_hess);
|
||||
}
|
||||
|
||||
@@ -172,12 +181,16 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
|
||||
auto col = page[fidx];
|
||||
// update grad value
|
||||
const auto num_row = static_cast<bst_omp_uint>(col.size());
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint j = 0; j < num_row; ++j) {
|
||||
GradientPair &p = (*in_gpair)[col[j].index * num_group + group_idx];
|
||||
if (p.GetHess() < 0.0f) continue;
|
||||
p += GradientPair(p.GetHess() * col[j].fvalue * dw, 0);
|
||||
exc.Run([&]() {
|
||||
GradientPair &p = (*in_gpair)[col[j].index * num_group + group_idx];
|
||||
if (p.GetHess() < 0.0f) return;
|
||||
p += GradientPair(p.GetHess() * col[j].fvalue * dw, 0);
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -195,12 +208,16 @@ inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias
|
||||
DMatrix *p_fmat) {
|
||||
if (dbias == 0.0f) return;
|
||||
const auto ndata = static_cast<bst_omp_uint>(p_fmat->Info().num_row_);
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
GradientPair &g = (*in_gpair)[i * num_group + group_idx];
|
||||
if (g.GetHess() < 0.0f) continue;
|
||||
g += GradientPair(g.GetHess() * dbias, 0);
|
||||
exc.Run([&]() {
|
||||
GradientPair &g = (*in_gpair)[i * num_group + group_idx];
|
||||
if (g.GetHess() < 0.0f) return;
|
||||
g += GradientPair(g.GetHess() * dbias, 0);
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -336,10 +353,9 @@ class GreedyFeatureSelector : public FeatureSelector {
|
||||
const bst_omp_uint nfeat = model.learner_model_param->num_feature;
|
||||
// Calculate univariate gradient sums
|
||||
std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
|
||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
||||
auto page = batch.GetView();
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nfeat; ++i) {
|
||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
||||
auto page = batch.GetView();
|
||||
common::ParallelFor(nfeat, [&](bst_omp_uint i) {
|
||||
const auto col = page[i];
|
||||
const bst_uint ndata = col.size();
|
||||
auto &sums = gpair_sums_[group_idx * nfeat + i];
|
||||
@@ -350,7 +366,7 @@ class GreedyFeatureSelector : public FeatureSelector {
|
||||
sums.first += p.GetGrad() * v;
|
||||
sums.second += p.GetHess() * v * v;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
// Find a feature with the largest magnitude of weight change
|
||||
int best_fidx = 0;
|
||||
@@ -405,8 +421,7 @@ class ThriftyFeatureSelector : public FeatureSelector {
|
||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
||||
auto page = batch.GetView();
|
||||
// column-parallel is usually faster than row-parallel
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nfeat; ++i) {
|
||||
common::ParallelFor(nfeat, [&](bst_omp_uint i) {
|
||||
const auto col = page[i];
|
||||
const bst_uint ndata = col.size();
|
||||
for (bst_uint gid = 0u; gid < ngroup; ++gid) {
|
||||
@@ -419,7 +434,7 @@ class ThriftyFeatureSelector : public FeatureSelector {
|
||||
sums.second += p.GetHess() * v * v;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
// rank by descending weight magnitude within the groups
|
||||
std::fill(deltaw_.begin(), deltaw_.end(), 0.f);
|
||||
|
||||
@@ -54,38 +54,42 @@ class ShotgunUpdater : public LinearUpdater {
|
||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
||||
auto page = batch.GetView();
|
||||
const auto nfeat = static_cast<bst_omp_uint>(batch.Size());
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nfeat; ++i) {
|
||||
int ii = selector_->NextFeature
|
||||
(i, *model, 0, in_gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm,
|
||||
param_.reg_lambda_denorm);
|
||||
if (ii < 0) continue;
|
||||
const bst_uint fid = ii;
|
||||
auto col = page[ii];
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
double sum_grad = 0.0, sum_hess = 0.0;
|
||||
for (auto& c : col) {
|
||||
const GradientPair &p = gpair[c.index * ngroup + gid];
|
||||
if (p.GetHess() < 0.0f) continue;
|
||||
const bst_float v = c.fvalue;
|
||||
sum_grad += p.GetGrad() * v;
|
||||
sum_hess += p.GetHess() * v * v;
|
||||
exc.Run([&]() {
|
||||
int ii = selector_->NextFeature
|
||||
(i, *model, 0, in_gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm,
|
||||
param_.reg_lambda_denorm);
|
||||
if (ii < 0) return;
|
||||
const bst_uint fid = ii;
|
||||
auto col = page[ii];
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
double sum_grad = 0.0, sum_hess = 0.0;
|
||||
for (auto& c : col) {
|
||||
const GradientPair &p = gpair[c.index * ngroup + gid];
|
||||
if (p.GetHess() < 0.0f) continue;
|
||||
const bst_float v = c.fvalue;
|
||||
sum_grad += p.GetGrad() * v;
|
||||
sum_hess += p.GetHess() * v * v;
|
||||
}
|
||||
bst_float &w = (*model)[fid][gid];
|
||||
auto dw = static_cast<bst_float>(
|
||||
param_.learning_rate *
|
||||
CoordinateDelta(sum_grad, sum_hess, w, param_.reg_alpha_denorm,
|
||||
param_.reg_lambda_denorm));
|
||||
if (dw == 0.f) continue;
|
||||
w += dw;
|
||||
// update grad values
|
||||
for (auto& c : col) {
|
||||
GradientPair &p = gpair[c.index * ngroup + gid];
|
||||
if (p.GetHess() < 0.0f) continue;
|
||||
p += GradientPair(p.GetHess() * c.fvalue * dw, 0);
|
||||
}
|
||||
}
|
||||
bst_float &w = (*model)[fid][gid];
|
||||
auto dw = static_cast<bst_float>(
|
||||
param_.learning_rate *
|
||||
CoordinateDelta(sum_grad, sum_hess, w, param_.reg_alpha_denorm,
|
||||
param_.reg_lambda_denorm));
|
||||
if (dw == 0.f) continue;
|
||||
w += dw;
|
||||
// update grad values
|
||||
for (auto& c : col) {
|
||||
GradientPair &p = gpair[c.index * ngroup + gid];
|
||||
if (p.GetHess() < 0.0f) continue;
|
||||
p += GradientPair(p.GetHess() * c.fvalue * dw, 0);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
exc.Rethrow();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user