Thread-safe prediction by making the prediction cache thread-local. (#5853)
Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include "helpers.h"
|
||||
#include <dmlc/filesystem.h>
|
||||
|
||||
@@ -176,6 +177,48 @@ TEST(Learner, JsonModelIO) {
|
||||
}
|
||||
}
|
||||
|
||||
// Crashes the test runner if there are race condiditions.
|
||||
//
|
||||
// Build with additional cmake flags to enable thread sanitizer
|
||||
// which definitely catches problems. Note that OpenMP needs to be
|
||||
// disabled, otherwise thread sanitizer will also report false
|
||||
// positives.
|
||||
//
|
||||
// ```
|
||||
// -DUSE_SANITIZER=ON -DENABLED_SANITIZERS=thread -DUSE_OPENMP=OFF
|
||||
// ```
|
||||
TEST(Learner, MultiThreadedPredict) {
|
||||
size_t constexpr kRows = 1000;
|
||||
size_t constexpr kCols = 1000;
|
||||
|
||||
std::shared_ptr<DMatrix> p_dmat{
|
||||
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
|
||||
p_dmat->Info().labels_.Resize(kRows);
|
||||
CHECK_NE(p_dmat->Info().num_col_, 0);
|
||||
|
||||
std::shared_ptr<DMatrix> p_data{
|
||||
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
|
||||
CHECK_NE(p_data->Info().num_col_, 0);
|
||||
|
||||
std::shared_ptr<Learner> learner{Learner::Create({p_dmat})};
|
||||
learner->Configure();
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
for (uint32_t thread_id = 0;
|
||||
thread_id < 2 * std::thread::hardware_concurrency(); ++thread_id) {
|
||||
threads.emplace_back([learner, p_data] {
|
||||
size_t constexpr kIters = 10;
|
||||
auto &entry = learner->GetThreadLocal().prediction_entry;
|
||||
for (size_t iter = 0; iter < kIters; ++iter) {
|
||||
learner->Predict(p_data, false, &entry.predictions);
|
||||
}
|
||||
});
|
||||
}
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Learner, BinaryModelIO) {
|
||||
size_t constexpr kRows = 8;
|
||||
int32_t constexpr kIters = 4;
|
||||
|
||||
Reference in New Issue
Block a user