Make prediction functions thread safe. (#6648)
This commit is contained in:
@@ -199,7 +199,7 @@ TEST(Learner, JsonModelIO) {
|
||||
// ```
|
||||
TEST(Learner, MultiThreadedPredict) {
|
||||
size_t constexpr kRows = 1000;
|
||||
size_t constexpr kCols = 1000;
|
||||
size_t constexpr kCols = 100;
|
||||
|
||||
std::shared_ptr<DMatrix> p_dmat{
|
||||
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix()};
|
||||
@@ -219,8 +219,11 @@ TEST(Learner, MultiThreadedPredict) {
|
||||
threads.emplace_back([learner, p_data] {
|
||||
size_t constexpr kIters = 10;
|
||||
auto &entry = learner->GetThreadLocal().prediction_entry;
|
||||
HostDeviceVector<float> predictions;
|
||||
for (size_t iter = 0; iter < kIters; ++iter) {
|
||||
learner->Predict(p_data, false, &entry.predictions);
|
||||
learner->Predict(p_data, false, &predictions, 0, true); // leaf
|
||||
learner->Predict(p_data, false, &predictions, 0, false, true); // contribs
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user