diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index e840643ff..ddab3710d 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -408,24 +408,26 @@ class Dart : public GBTree { constexpr int kUnroll = 8; const auto nsize = static_cast(batch.Size()); const bst_omp_uint rest = nsize % kUnroll; - #pragma omp parallel for schedule(static) - for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) { - const int tid = omp_get_thread_num(); - RegTree::FVec& feats = thread_temp_[tid]; - int64_t ridx[kUnroll]; - SparsePage::Inst inst[kUnroll]; - for (int k = 0; k < kUnroll; ++k) { - ridx[k] = static_cast(batch.base_rowid + i + k); - } - for (int k = 0; k < kUnroll; ++k) { - inst[k] = batch[i + k]; - } - for (int k = 0; k < kUnroll; ++k) { - for (int gid = 0; gid < num_group; ++gid) { - const size_t offset = ridx[k] * num_group + gid; - preds[offset] += - self->PredValue(inst[k], gid, info.GetRoot(ridx[k]), - &feats, tree_begin, tree_end); + if (nsize >= kUnroll) { + #pragma omp parallel for schedule(static) + for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) { + const int tid = omp_get_thread_num(); + RegTree::FVec& feats = thread_temp_[tid]; + int64_t ridx[kUnroll]; + SparsePage::Inst inst[kUnroll]; + for (int k = 0; k < kUnroll; ++k) { + ridx[k] = static_cast(batch.base_rowid + i + k); + } + for (int k = 0; k < kUnroll; ++k) { + inst[k] = batch[i + k]; + } + for (int k = 0; k < kUnroll; ++k) { + for (int gid = 0; gid < num_group; ++gid) { + const size_t offset = ridx[k] * num_group + gid; + preds[offset] += + self->PredValue(inst[k], gid, info.GetRoot(ridx[k]), + &feats, tree_begin, tree_end); + } } } } diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 4a1dccb8b..302c89994 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -63,24 +63,26 @@ class CPUPredictor : public Predictor { // Pull to host before entering omp block, as this is not thread safe. batch.data.HostVector(); batch.offset.HostVector(); + if (nsize >= kUnroll) { #pragma omp parallel for schedule(static) - for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) { - const int tid = omp_get_thread_num(); - RegTree::FVec& feats = thread_temp[tid]; - int64_t ridx[kUnroll]; - SparsePage::Inst inst[kUnroll]; - for (int k = 0; k < kUnroll; ++k) { - ridx[k] = static_cast(batch.base_rowid + i + k); - } - for (int k = 0; k < kUnroll; ++k) { - inst[k] = batch[i + k]; - } - for (int k = 0; k < kUnroll; ++k) { - for (int gid = 0; gid < num_group; ++gid) { - const size_t offset = ridx[k] * num_group + gid; - preds[offset] += this->PredValue( - inst[k], model.trees, model.tree_info, gid, - info.GetRoot(ridx[k]), &feats, tree_begin, tree_end); + for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) { + const int tid = omp_get_thread_num(); + RegTree::FVec& feats = thread_temp[tid]; + int64_t ridx[kUnroll]; + SparsePage::Inst inst[kUnroll]; + for (int k = 0; k < kUnroll; ++k) { + ridx[k] = static_cast(batch.base_rowid + i + k); + } + for (int k = 0; k < kUnroll; ++k) { + inst[k] = batch[i + k]; + } + for (int k = 0; k < kUnroll; ++k) { + for (int gid = 0; gid < num_group; ++gid) { + const size_t offset = ridx[k] * num_group + gid; + preds[offset] += this->PredValue( + inst[k], model.trees, model.tree_info, gid, + info.GetRoot(ridx[k]), &feats, tree_begin, tree_end); + } } } }