Improve operation efficiency for single predict (#5016)
* Improve operation efficiency for single predict
This commit is contained in:
parent 374648c21a
commit 1733c9e8f7
@ -408,6 +408,7 @@ class Dart : public GBTree {
|
|||||||
constexpr int kUnroll = 8;
|
constexpr int kUnroll = 8;
|
||||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||||
const bst_omp_uint rest = nsize % kUnroll;
|
const bst_omp_uint rest = nsize % kUnroll;
|
||||||
|
if (nsize >= kUnroll) {
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||||
const int tid = omp_get_thread_num();
|
const int tid = omp_get_thread_num();
|
||||||
@ -429,6 +430,7 @@ class Dart : public GBTree {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
|
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
|
||||||
RegTree::FVec& feats = thread_temp_[0];
|
RegTree::FVec& feats = thread_temp_[0];
|
||||||
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
|
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||||
|
|||||||
@ -63,6 +63,7 @@ class CPUPredictor : public Predictor {
|
|||||||
// Pull to host before entering omp block, as this is not thread safe.
|
// Pull to host before entering omp block, as this is not thread safe.
|
||||||
batch.data.HostVector();
|
batch.data.HostVector();
|
||||||
batch.offset.HostVector();
|
batch.offset.HostVector();
|
||||||
|
if (nsize >= kUnroll) {
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||||
const int tid = omp_get_thread_num();
|
const int tid = omp_get_thread_num();
|
||||||
@ -84,6 +85,7 @@ class CPUPredictor : public Predictor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
|
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
|
||||||
RegTree::FVec& feats = thread_temp[0];
|
RegTree::FVec& feats = thread_temp[0];
|
||||||
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
|
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user