Improve operation efficiency for single predict (#5016)

* Improve operation efficiency for single predict
Authored by KaiJin Ji on 2019-11-10 02:01:28 +08:00; committed by Jiaming Yuan
parent 374648c21a
commit 1733c9e8f7
2 changed files with 39 additions and 35 deletions


@@ -408,24 +408,26 @@ class Dart : public GBTree {
     constexpr int kUnroll = 8;
     const auto nsize = static_cast<bst_omp_uint>(batch.Size());
     const bst_omp_uint rest = nsize % kUnroll;
-#pragma omp parallel for schedule(static)
-    for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
-      const int tid = omp_get_thread_num();
-      RegTree::FVec& feats = thread_temp_[tid];
-      int64_t ridx[kUnroll];
-      SparsePage::Inst inst[kUnroll];
-      for (int k = 0; k < kUnroll; ++k) {
-        ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
-      }
-      for (int k = 0; k < kUnroll; ++k) {
-        inst[k] = batch[i + k];
-      }
-      for (int k = 0; k < kUnroll; ++k) {
-        for (int gid = 0; gid < num_group; ++gid) {
-          const size_t offset = ridx[k] * num_group + gid;
-          preds[offset] +=
-              self->PredValue(inst[k], gid, info.GetRoot(ridx[k]),
-                              &feats, tree_begin, tree_end);
-        }
-      }
-    }
+    if (nsize >= kUnroll) {
+#pragma omp parallel for schedule(static)
+      for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
+        const int tid = omp_get_thread_num();
+        RegTree::FVec& feats = thread_temp_[tid];
+        int64_t ridx[kUnroll];
+        SparsePage::Inst inst[kUnroll];
+        for (int k = 0; k < kUnroll; ++k) {
+          ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
+        }
+        for (int k = 0; k < kUnroll; ++k) {
+          inst[k] = batch[i + k];
+        }
+        for (int k = 0; k < kUnroll; ++k) {
+          for (int gid = 0; gid < num_group; ++gid) {
+            const size_t offset = ridx[k] * num_group + gid;
+            preds[offset] +=
+                self->PredValue(inst[k], gid, info.GetRoot(ridx[k]),
+                                &feats, tree_begin, tree_end);
+          }
+        }
+      }
+    }
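The hunk above wraps the Dart predictor's unrolled, OpenMP-parallel main loop in a guard on nsize >= kUnroll. A batch with fewer rows than the unroll factor (the common case when predicting a single row) now skips the OpenMP parallel region entirely and leaves its rows to the serial remainder loop that handles the trailing rest rows (not shown in this hunk). The second hunk below applies the same guard in the CPU predictor; a standalone sketch of the pattern follows it.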


@@ -63,24 +63,26 @@ class CPUPredictor : public Predictor {
     // Pull to host before entering omp block, as this is not thread safe.
     batch.data.HostVector();
     batch.offset.HostVector();
+    if (nsize >= kUnroll) {
 #pragma omp parallel for schedule(static)
     for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
       const int tid = omp_get_thread_num();
       RegTree::FVec& feats = thread_temp[tid];
       int64_t ridx[kUnroll];
       SparsePage::Inst inst[kUnroll];
       for (int k = 0; k < kUnroll; ++k) {
         ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
       }
       for (int k = 0; k < kUnroll; ++k) {
         inst[k] = batch[i + k];
       }
       for (int k = 0; k < kUnroll; ++k) {
         for (int gid = 0; gid < num_group; ++gid) {
           const size_t offset = ridx[k] * num_group + gid;
           preds[offset] += this->PredValue(
               inst[k], model.trees, model.tree_info, gid,
               info.GetRoot(ridx[k]), &feats, tree_begin, tree_end);
         }
       }
     }
+    }
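A minimal, self-contained sketch of the guard-plus-remainder pattern both hunks rely on is shown below. This is not XGBoost code: the PredValue call is replaced by a trivial accumulation, and the names values and preds are placeholders. It illustrates that when a batch holds fewer than kUnroll rows, the parallel unrolled loop is skipped and the serial remainder loop covers the whole batch.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  constexpr int kUnroll = 8;

  // A "batch" with a single row, the case the commit optimises for.
  std::vector<double> values{2.0};
  std::vector<double> preds(values.size(), 0.0);

  const auto nsize = static_cast<std::int64_t>(values.size());
  const std::int64_t rest = nsize % kUnroll;

  // Guard corresponding to the one added in this commit: skip the parallel,
  // unrolled main loop when the batch is smaller than the unroll factor.
  if (nsize >= kUnroll) {
#pragma omp parallel for schedule(static)
    for (std::int64_t i = 0; i < nsize - rest; i += kUnroll) {
      for (int k = 0; k < kUnroll; ++k) {
        preds[i + k] += values[i + k];  // stand-in for PredValue(...)
      }
    }
  }

  // Serial remainder loop: handles the trailing `rest` rows, and the whole
  // batch when nsize < kUnroll (e.g. a single-row predict call).
  for (std::int64_t i = nsize - rest; i < nsize; ++i) {
    preds[i] += values[i];
  }

  std::cout << preds[0] << "\n";  // prints 2
  return 0;
}

Even without the guard the unrolled loop performs no iterations when nsize < kUnroll (rest equals nsize, so nsize - rest is zero); the gain is in not paying the cost of entering the OpenMP parallel region on every small predict call, which is presumably what the commit title means by "single predict".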