Fix for issue 3306. (#3324)

Authored by Thejaswi on 2018-05-23 07:12:20 +05:30; committed by Rory Mitchell
parent 8f6aadd4b7
commit d367e4fc6b
3 changed files with 6 additions and 11 deletions

View File

@@ -1004,7 +1004,7 @@ class AllReducer {
 template <typename T, typename FunctionT>
 void ExecuteShards(std::vector<T> *shards, FunctionT f) {
-#pragma omp parallel for schedule(static, 1)
+#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
   for (int shard = 0; shard < shards->size(); ++shard) {
     f(shards->at(shard));
   }
@@ -1023,7 +1023,7 @@ void ExecuteShards(std::vector<T> *shards, FunctionT f) {
 template <typename T, typename FunctionT>
 void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
-#pragma omp parallel for schedule(static, 1)
+#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
   for (int shard = 0; shard < shards->size(); ++shard) {
     f(shard, shards->at(shard));
   }
@@ -1045,7 +1045,7 @@ void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
 template <typename ReduceT,typename T, typename FunctionT>
 ReduceT ReduceShards(std::vector<T> *shards, FunctionT f) {
   std::vector<ReduceT> sums(shards->size());
-#pragma omp parallel for schedule(static, 1)
+#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
   for (int shard = 0; shard < shards->size(); ++shard) {
     sums[shard] = f(shards->at(shard));
   }
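
All three shard helpers get the same one-line fix: an OpenMP `if` clause makes the parallel region conditional at runtime, so a single shard is processed serially on the calling thread instead of through a freshly forked team of one. A minimal standalone sketch of the clause's behavior (not part of the patch; the vector and prints are illustrative):

// Sketch only: demonstrates the OpenMP `if` clause used throughout this patch.
#include <omp.h>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> shards(1);  // one shard: the if clause suppresses forking
#pragma omp parallel for schedule(static, 1) if (shards.size() > 1)
  for (int i = 0; i < static_cast<int>(shards.size()); ++i) {
    // With a single shard this reports a team of 1: the loop ran serially.
    std::printf("shard %d handled by thread %d of %d\n", i,
                omp_get_thread_num(), omp_get_num_threads());
  }
  return 0;
}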

View File

@@ -55,9 +55,7 @@ class RegLossObj : public ObjFunction {
     avx::Float8 scale(param_.scale_pos_weight);
     const omp_ulong remainder = n % 8;
-    int nthread = omp_get_max_threads();
-    // Use a maximum of 8 threads
-#pragma omp parallel for schedule(static) num_threads(std::min(8, nthread))
+#pragma omp parallel for schedule(static)
     for (omp_ulong i = 0; i < n - remainder; i += 8) {
       avx::Float8 y(&info.labels_[i]);
       avx::Float8 p = Loss::PredTransform(avx::Float8(&preds_h[i]));
@@ -77,9 +75,6 @@ class RegLossObj : public ObjFunction {
       gpair[i] = GradientPair(Loss::FirstOrderGradient(p, y) * w,
                               Loss::SecondOrderGradient(p, y) * w);
     }
-    // Reset omp max threads
-    omp_set_num_threads(nthread);
   }
   const char *DefaultEvalMetric() const override {
     return Loss::DefaultEvalMetric();
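
These two hunks drop the eight-thread cap on the AVX gradient loop along with its bookkeeping. The deleted `num_threads` clause only ever applied to the one region it annotates, so the trailing `omp_set_num_threads(nthread)` "reset" was restoring a global setting the loop never actually changed; removing both leaves a plain static schedule. For reference, a hedged sketch of the difference between the scoped clause and the global call (prints are illustrative):

// Sketch only: scoped num_threads clause vs. global omp_set_num_threads.
#include <omp.h>
#include <cstdio>

int main() {
  // The clause caps only this region; later regions are unaffected.
#pragma omp parallel num_threads(2)
#pragma omp master
  std::printf("clause-capped region: %d threads\n", omp_get_num_threads());

  // The library call persists for all subsequent parallel regions,
  // which is the global side effect a manual "reset" has to undo.
  omp_set_num_threads(2);
#pragma omp parallel
#pragma omp master
  std::printf("after omp_set_num_threads: %d threads\n",
              omp_get_num_threads());
  return 0;
}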

View File

@@ -136,7 +136,7 @@ class GPURegLossObj : public ObjFunction {
     }
     // run the kernel
-#pragma omp parallel for schedule(static, 1)
+#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1)
     for (int i = 0; i < devices_.Size(); ++i) {
       int d = devices_[i];
       dh::safe_cuda(cudaSetDevice(d));
@@ -173,7 +173,7 @@ class GPURegLossObj : public ObjFunction {
   }
   void PredTransformDevice(HostDeviceVector<float>* preds) {
-#pragma omp parallel for schedule(static, 1)
+#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1)
     for (int i = 0; i < devices_.Size(); ++i) {
       int d = devices_[i];
       dh::safe_cuda(cudaSetDevice(d));
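
The GPU objective gets the same `if` clause on its per-device loops. Here `schedule(static, 1)` deals iterations out round-robin in chunks of one, so with enough threads each device is driven by its own host thread, while the clause keeps the common single-GPU case on the calling thread. A standalone sketch of the dispatch pattern (`LaunchOnDevice` and the device list are hypothetical stand-ins for the `cudaSetDevice` calls above):

// Sketch only: one host thread per device, serial when only one device.
#include <omp.h>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for cudaSetDevice(d) plus the kernel launch.
void LaunchOnDevice(int device) {
  std::printf("device %d driven by host thread %d\n",
              device, omp_get_thread_num());
}

int main() {
  std::vector<int> devices = {0, 1};  // hypothetical device ordinals
#pragma omp parallel for schedule(static, 1) if (devices.size() > 1)
  for (int i = 0; i < static_cast<int>(devices.size()); ++i) {
    LaunchOnDevice(devices[i]);
  }
  return 0;
}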