Fix cache with gc (#8851)

- Make DMatrixCache thread-safe.
- Remove the use of thread-local memory.
This commit is contained in:
Jiaming Yuan
2023-03-01 00:39:06 +08:00
committed by GitHub
parent d9688f93c7
commit d54ef56f6f
4 changed files with 143 additions and 65 deletions

View File

@@ -3,16 +3,18 @@
*/
#include <gtest/gtest.h>
#include <xgboost/cache.h>
#include <xgboost/data.h> // DMatrix
#include <xgboost/data.h> // for DMatrix
#include <cstddef> // std::size_t
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t
#include <thread> // for thread
#include "helpers.h" // RandomDataGenerator
#include "helpers.h" // for RandomDataGenerator
namespace xgboost {
namespace {
struct CacheForTest {
std::size_t i;
std::size_t const i;
explicit CacheForTest(std::size_t k) : i{k} {}
};
@@ -20,7 +22,7 @@ struct CacheForTest {
TEST(DMatrixCache, Basic) {
std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 4;
DMatrixCache<CacheForTest> cache(kCacheSize);
DMatrixCache<CacheForTest> cache{kCacheSize};
auto add_cache = [&]() {
// Create a lambda function here, so that p_fmat gets deleted upon the
@@ -52,4 +54,63 @@ TEST(DMatrixCache, Basic) {
}
}
}
TEST(DMatrixCache, MultiThread) {
std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 3;
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
auto n = std::thread::hardware_concurrency() * 128u;
CHECK_NE(n, 0);
std::vector<std::shared_ptr<CacheForTest>> results(n);
{
DMatrixCache<CacheForTest> cache{kCacheSize};
std::vector<std::thread> tasks;
for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
tasks.emplace_back([&, i = tidx]() {
cache.CacheItem(p_fmat, i);
auto p_fmat_local = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
results[i] = cache.CacheItem(p_fmat_local, i);
});
}
for (auto& t : tasks) {
t.join();
}
for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
ASSERT_EQ(results[tidx]->i, tidx);
}
tasks.clear();
for (std::int32_t tidx = static_cast<std::int32_t>(n - 1); tidx >= 0; --tidx) {
tasks.emplace_back([&, i = tidx]() {
cache.CacheItem(p_fmat, i);
auto p_fmat_local = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
results[i] = cache.CacheItem(p_fmat_local, i);
});
}
for (auto& t : tasks) {
t.join();
}
for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
ASSERT_EQ(results[tidx]->i, tidx);
}
}
{
DMatrixCache<CacheForTest> cache{n};
std::vector<std::thread> tasks;
for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
tasks.emplace_back([&, tidx]() { results[tidx] = cache.CacheItem(p_fmat, tidx); });
}
for (auto& t : tasks) {
t.join();
}
for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
ASSERT_EQ(results[tidx]->i, tidx);
}
}
}
} // namespace xgboost