Improve operation efficiency for single predict (#5016)
* Improve operation efficiency for single predict
This commit is contained in:
parent
374648c21a
commit
1733c9e8f7
@ -408,24 +408,26 @@ class Dart : public GBTree {
|
||||
constexpr int kUnroll = 8;
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
const bst_omp_uint rest = nsize % kUnroll;
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec& feats = thread_temp_[tid];
|
||||
int64_t ridx[kUnroll];
|
||||
SparsePage::Inst inst[kUnroll];
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
inst[k] = batch[i + k];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] +=
|
||||
self->PredValue(inst[k], gid, info.GetRoot(ridx[k]),
|
||||
&feats, tree_begin, tree_end);
|
||||
if (nsize >= kUnroll) {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec& feats = thread_temp_[tid];
|
||||
int64_t ridx[kUnroll];
|
||||
SparsePage::Inst inst[kUnroll];
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
inst[k] = batch[i + k];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] +=
|
||||
self->PredValue(inst[k], gid, info.GetRoot(ridx[k]),
|
||||
&feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -63,24 +63,26 @@ class CPUPredictor : public Predictor {
|
||||
// Pull to host before entering omp block, as this is not thread safe.
|
||||
batch.data.HostVector();
|
||||
batch.offset.HostVector();
|
||||
if (nsize >= kUnroll) {
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec& feats = thread_temp[tid];
|
||||
int64_t ridx[kUnroll];
|
||||
SparsePage::Inst inst[kUnroll];
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
inst[k] = batch[i + k];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] += this->PredValue(
|
||||
inst[k], model.trees, model.tree_info, gid,
|
||||
info.GetRoot(ridx[k]), &feats, tree_begin, tree_end);
|
||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
||||
const int tid = omp_get_thread_num();
|
||||
RegTree::FVec& feats = thread_temp[tid];
|
||||
int64_t ridx[kUnroll];
|
||||
SparsePage::Inst inst[kUnroll];
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
inst[k] = batch[i + k];
|
||||
}
|
||||
for (int k = 0; k < kUnroll; ++k) {
|
||||
for (int gid = 0; gid < num_group; ++gid) {
|
||||
const size_t offset = ridx[k] * num_group + gid;
|
||||
preds[offset] += this->PredValue(
|
||||
inst[k], model.trees, model.tree_info, gid,
|
||||
info.GetRoot(ridx[k]), &feats, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user