Make prediction functions thread safe. (#6648)

This commit is contained in:
Jiaming Yuan
2021-01-28 23:29:43 +08:00
committed by GitHub
parent 0f2ed21a9d
commit c3c8e66fc9
4 changed files with 72 additions and 70 deletions

View File

@@ -199,7 +199,7 @@ TEST(Learner, JsonModelIO) {
// ```
TEST(Learner, MultiThreadedPredict) {
size_t constexpr kRows = 1000;
size_t constexpr kCols = 1000;
size_t constexpr kCols = 100;
std::shared_ptr<DMatrix> p_dmat{
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
@@ -219,8 +219,11 @@ TEST(Learner, MultiThreadedPredict) {
threads.emplace_back([learner, p_data] {
size_t constexpr kIters = 10;
auto &entry = learner->GetThreadLocal().prediction_entry;
HostDeviceVector<float> predictions;
for (size_t iter = 0; iter < kIters; ++iter) {
learner->Predict(p_data, false, &entry.predictions);
learner->Predict(p_data, false, &predictions, 0, true); // leaf
learner->Predict(p_data, false, &predictions, 0, false, true); // contribs
}
});
}