Thread-safe prediction by making the prediction cache thread-local. (#5853)

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
boxdot
2020-07-30 06:33:50 +02:00
committed by GitHub
parent fa3715f584
commit d268a2a463
5 changed files with 71 additions and 14 deletions

View File

@@ -3,6 +3,7 @@
*/
#include <gtest/gtest.h>
#include <vector>
#include <thread>
#include "helpers.h"
#include <dmlc/filesystem.h>
@@ -176,6 +177,48 @@ TEST(Learner, JsonModelIO) {
}
}
// Crashes the test runner if there are race condiditions.
//
// Build with additional cmake flags to enable thread sanitizer
// which definitely catches problems. Note that OpenMP needs to be
// disabled, otherwise thread sanitizer will also report false
// positives.
//
// ```
// -DUSE_SANITIZER=ON -DENABLED_SANITIZERS=thread -DUSE_OPENMP=OFF
// ```
TEST(Learner, MultiThreadedPredict) {
size_t constexpr kRows = 1000;
size_t constexpr kCols = 1000;
std::shared_ptr<DMatrix> p_dmat{
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
p_dmat->Info().labels_.Resize(kRows);
CHECK_NE(p_dmat->Info().num_col_, 0);
std::shared_ptr<DMatrix> p_data{
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
CHECK_NE(p_data->Info().num_col_, 0);
std::shared_ptr<Learner> learner{Learner::Create({p_dmat})};
learner->Configure();
std::vector<std::thread> threads;
for (uint32_t thread_id = 0;
thread_id < 2 * std::thread::hardware_concurrency(); ++thread_id) {
threads.emplace_back([learner, p_data] {
size_t constexpr kIters = 10;
auto &entry = learner->GetThreadLocal().prediction_entry;
for (size_t iter = 0; iter < kIters; ++iter) {
learner->Predict(p_data, false, &entry.predictions);
}
});
}
for (auto &thread : threads) {
thread.join();
}
}
TEST(Learner, BinaryModelIO) {
size_t constexpr kRows = 8;
int32_t constexpr kIters = 4;