Fix cache with gc (#8851)
- Make DMatrixCache thread-safe. - Remove the use of thread-local memory.
This commit is contained in:
parent
d9688f93c7
commit
d54ef56f6f
@ -6,16 +6,18 @@
|
|||||||
|
|
||||||
#include <xgboost/logging.h> // CHECK_EQ
|
#include <xgboost/logging.h> // CHECK_EQ
|
||||||
|
|
||||||
#include <cstddef> // std::size_t
|
#include <cstddef> // for size_t
|
||||||
#include <memory> // std::weak_ptr,std::shared_ptr,std::make_shared
|
#include <memory> // for weak_ptr, shared_ptr, make_shared
|
||||||
#include <queue> // std:queue
|
#include <mutex> // for mutex, lock_guard
|
||||||
#include <unordered_map> // std::unordered_map
|
#include <queue> // for queue
|
||||||
#include <vector> // std::vector
|
#include <thread> // for thread
|
||||||
|
#include <unordered_map> // for unordered_map
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
class DMatrix;
|
class DMatrix;
|
||||||
/**
|
/**
|
||||||
* \brief FIFO cache for DMatrix related data.
|
* \brief Thread-aware FIFO cache for DMatrix related data.
|
||||||
*
|
*
|
||||||
* \tparam CacheT The type that needs to be cached.
|
* \tparam CacheT The type that needs to be cached.
|
||||||
*/
|
*/
|
||||||
@ -34,9 +36,31 @@ class DMatrixCache {
|
|||||||
|
|
||||||
static constexpr std::size_t DefaultSize() { return 32; }
|
static constexpr std::size_t DefaultSize() { return 32; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
mutable std::mutex lock_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::unordered_map<DMatrix const*, Item> container_;
|
struct Key {
|
||||||
std::queue<DMatrix const*> queue_;
|
DMatrix const* ptr;
|
||||||
|
std::thread::id const thread_id;
|
||||||
|
|
||||||
|
bool operator==(Key const& that) const {
|
||||||
|
return ptr == that.ptr && thread_id == that.thread_id;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
struct Hash {
|
||||||
|
std::size_t operator()(Key const& key) const noexcept {
|
||||||
|
std::size_t f = std::hash<DMatrix const*>()(key.ptr);
|
||||||
|
std::size_t s = std::hash<std::thread::id>()(key.thread_id);
|
||||||
|
if (f == s) {
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
return f ^ s;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
std::unordered_map<Key, Item, Hash> container_;
|
||||||
|
std::queue<Key> queue_;
|
||||||
std::size_t max_size_;
|
std::size_t max_size_;
|
||||||
|
|
||||||
void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
|
void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
|
||||||
@ -44,8 +68,8 @@ class DMatrixCache {
|
|||||||
void ClearExpired() {
|
void ClearExpired() {
|
||||||
// Clear expired entries
|
// Clear expired entries
|
||||||
this->CheckConsistent();
|
this->CheckConsistent();
|
||||||
std::vector<DMatrix const*> expired;
|
std::vector<Key> expired;
|
||||||
std::queue<DMatrix const*> remained;
|
std::queue<Key> remained;
|
||||||
|
|
||||||
while (!queue_.empty()) {
|
while (!queue_.empty()) {
|
||||||
auto p_fmat = queue_.front();
|
auto p_fmat = queue_.front();
|
||||||
@ -61,8 +85,8 @@ class DMatrixCache {
|
|||||||
CHECK(queue_.empty());
|
CHECK(queue_.empty());
|
||||||
CHECK_EQ(remained.size() + expired.size(), container_.size());
|
CHECK_EQ(remained.size() + expired.size(), container_.size());
|
||||||
|
|
||||||
for (auto const* p_fmat : expired) {
|
for (auto const& key : expired) {
|
||||||
container_.erase(p_fmat);
|
container_.erase(key);
|
||||||
}
|
}
|
||||||
while (!remained.empty()) {
|
while (!remained.empty()) {
|
||||||
auto p_fmat = remained.front();
|
auto p_fmat = remained.front();
|
||||||
@ -74,7 +98,9 @@ class DMatrixCache {
|
|||||||
|
|
||||||
void ClearExcess() {
|
void ClearExcess() {
|
||||||
this->CheckConsistent();
|
this->CheckConsistent();
|
||||||
while (queue_.size() >= max_size_) {
|
// clear half of the entries to prevent repeatingly clearing cache.
|
||||||
|
std::size_t half_size = max_size_ / 2;
|
||||||
|
while (queue_.size() >= half_size && !queue_.empty()) {
|
||||||
auto p_fmat = queue_.front();
|
auto p_fmat = queue_.front();
|
||||||
queue_.pop();
|
queue_.pop();
|
||||||
container_.erase(p_fmat);
|
container_.erase(p_fmat);
|
||||||
@ -88,7 +114,7 @@ class DMatrixCache {
|
|||||||
*/
|
*/
|
||||||
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
|
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
|
||||||
/**
|
/**
|
||||||
* \brief Cache a new DMatrix if it's no in the cache already.
|
* \brief Cache a new DMatrix if it's not in the cache already.
|
||||||
*
|
*
|
||||||
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
|
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
|
||||||
* entry this shared pointer is necessary. More importantly, the life time of this
|
* entry this shared pointer is necessary. More importantly, the life time of this
|
||||||
@ -101,35 +127,42 @@ class DMatrixCache {
|
|||||||
* created.
|
* created.
|
||||||
*/
|
*/
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
std::shared_ptr<CacheT>& CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
|
std::shared_ptr<CacheT> CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
|
||||||
CHECK(m);
|
CHECK(m);
|
||||||
|
std::lock_guard<std::mutex> guard{lock_};
|
||||||
|
|
||||||
this->ClearExpired();
|
this->ClearExpired();
|
||||||
if (container_.size() >= max_size_) {
|
if (container_.size() >= max_size_) {
|
||||||
this->ClearExcess();
|
this->ClearExcess();
|
||||||
}
|
}
|
||||||
// after clear, cache size < max_size
|
// after clear, cache size < max_size
|
||||||
CHECK_LT(container_.size(), max_size_);
|
CHECK_LT(container_.size(), max_size_);
|
||||||
auto it = container_.find(m.get());
|
auto key = Key{m.get(), std::this_thread::get_id()};
|
||||||
|
auto it = container_.find(key);
|
||||||
if (it == container_.cend()) {
|
if (it == container_.cend()) {
|
||||||
// after the new DMatrix, cache size is at most max_size
|
// after the new DMatrix, cache size is at most max_size
|
||||||
container_[m.get()] = {m, std::make_shared<CacheT>(args...)};
|
container_[key] = {m, std::make_shared<CacheT>(args...)};
|
||||||
queue_.push(m.get());
|
queue_.emplace(key);
|
||||||
}
|
}
|
||||||
return container_.at(m.get()).value;
|
return container_.at(key).value;
|
||||||
}
|
}
|
||||||
/**
|
/**
|
||||||
* \brief Get a const reference to the underlying hash map. Clear expired caches before
|
* \brief Get a const reference to the underlying hash map. Clear expired caches before
|
||||||
* returning.
|
* returning.
|
||||||
*/
|
*/
|
||||||
decltype(container_) const& Container() {
|
decltype(container_) const& Container() {
|
||||||
|
std::lock_guard<std::mutex> guard{lock_};
|
||||||
|
|
||||||
this->ClearExpired();
|
this->ClearExpired();
|
||||||
return container_;
|
return container_;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
|
std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
|
||||||
CHECK(container_.find(m) != container_.cend());
|
std::lock_guard<std::mutex> guard{lock_};
|
||||||
CHECK(!container_.at(m).ref.expired());
|
auto key = Key{m, std::this_thread::get_id()};
|
||||||
return container_.at(m).value;
|
CHECK(container_.find(key) != container_.cend());
|
||||||
|
CHECK(!container_.at(key).ref.expired());
|
||||||
|
return container_.at(key).value;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -14,6 +14,8 @@
|
|||||||
#include <functional> // std::function
|
#include <functional> // std::function
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <thread> // for get_id
|
||||||
|
#include <utility> // for make_pair
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// Forward declarations
|
// Forward declarations
|
||||||
@ -48,18 +50,17 @@ struct PredictionCacheEntry {
|
|||||||
* \brief A container for managed prediction caches.
|
* \brief A container for managed prediction caches.
|
||||||
*/
|
*/
|
||||||
class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
|
class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
|
||||||
// we cache up to 32 DMatrix
|
// We cache up to 64 DMatrix for all threads
|
||||||
std::size_t static constexpr DefaultSize() { return 32; }
|
std::size_t static constexpr DefaultSize() { return 64; }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
|
PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
|
||||||
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, std::int32_t device) {
|
||||||
this->CacheItem(m);
|
auto p_cache = this->CacheItem(m);
|
||||||
auto p_cache = this->container_.find(m.get());
|
|
||||||
if (device != Context::kCpuId) {
|
if (device != Context::kCpuId) {
|
||||||
p_cache->second.Value().predictions.SetDevice(device);
|
p_cache->predictions.SetDevice(device);
|
||||||
}
|
}
|
||||||
return p_cache->second.Value();
|
return *p_cache;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -328,9 +328,6 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
|
|||||||
using LearnerAPIThreadLocalStore =
|
using LearnerAPIThreadLocalStore =
|
||||||
dmlc::ThreadLocalStore<std::map<Learner const *, XGBAPIThreadLocalEntry>>;
|
dmlc::ThreadLocalStore<std::map<Learner const *, XGBAPIThreadLocalEntry>>;
|
||||||
|
|
||||||
using ThreadLocalPredictionCache =
|
|
||||||
dmlc::ThreadLocalStore<std::map<Learner const *, PredictionContainer>>;
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
StringView ModelMsg() {
|
StringView ModelMsg() {
|
||||||
return StringView{
|
return StringView{
|
||||||
@ -368,6 +365,8 @@ class LearnerConfiguration : public Learner {
|
|||||||
LearnerModelParam learner_model_param_;
|
LearnerModelParam learner_model_param_;
|
||||||
LearnerTrainParam tparam_;
|
LearnerTrainParam tparam_;
|
||||||
// Initial prediction.
|
// Initial prediction.
|
||||||
|
PredictionContainer prediction_container_;
|
||||||
|
|
||||||
std::vector<std::string> metric_names_;
|
std::vector<std::string> metric_names_;
|
||||||
|
|
||||||
void ConfigureModelParamWithoutBaseScore() {
|
void ConfigureModelParamWithoutBaseScore() {
|
||||||
@ -426,22 +425,15 @@ class LearnerConfiguration : public Learner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
|
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix>> cache)
|
||||||
: need_configuration_{true} {
|
: need_configuration_{true} {
|
||||||
monitor_.Init("Learner");
|
monitor_.Init("Learner");
|
||||||
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
|
|
||||||
for (std::shared_ptr<DMatrix> const& d : cache) {
|
for (std::shared_ptr<DMatrix> const& d : cache) {
|
||||||
if (d) {
|
if (d) {
|
||||||
local_cache.Cache(d, Context::kCpuId);
|
prediction_container_.Cache(d, Context::kCpuId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
~LearnerConfiguration() override {
|
|
||||||
auto local_cache = ThreadLocalPredictionCache::Get();
|
|
||||||
if (local_cache->find(this) != local_cache->cend()) {
|
|
||||||
local_cache->erase(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Configuration before data is known.
|
// Configuration before data is known.
|
||||||
void Configure() override {
|
void Configure() override {
|
||||||
@ -499,10 +491,6 @@ class LearnerConfiguration : public Learner {
|
|||||||
CHECK_NE(learner_model_param_.BaseScore(this->Ctx()).Size(), 0) << ModelNotFitted();
|
CHECK_NE(learner_model_param_.BaseScore(this->Ctx()).Size(), 0) << ModelNotFitted();
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual PredictionContainer* GetPredictionCache() const {
|
|
||||||
return &((*ThreadLocalPredictionCache::Get())[this]);
|
|
||||||
}
|
|
||||||
|
|
||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const& in) override {
|
||||||
// If configuration is loaded, ensure that the model came from the same version
|
// If configuration is loaded, ensure that the model came from the same version
|
||||||
CHECK(IsA<Object>(in));
|
CHECK(IsA<Object>(in));
|
||||||
@ -741,11 +729,10 @@ class LearnerConfiguration : public Learner {
|
|||||||
if (mparam_.num_feature == 0) {
|
if (mparam_.num_feature == 0) {
|
||||||
// TODO(hcho3): Change num_feature to 64-bit integer
|
// TODO(hcho3): Change num_feature to 64-bit integer
|
||||||
unsigned num_feature = 0;
|
unsigned num_feature = 0;
|
||||||
auto local_cache = this->GetPredictionCache();
|
for (auto const& matrix : prediction_container_.Container()) {
|
||||||
for (auto& matrix : local_cache->Container()) {
|
CHECK(matrix.first.ptr);
|
||||||
CHECK(matrix.first);
|
|
||||||
CHECK(!matrix.second.ref.expired());
|
CHECK(!matrix.second.ref.expired());
|
||||||
const uint64_t num_col = matrix.first->Info().num_col_;
|
const uint64_t num_col = matrix.first.ptr->Info().num_col_;
|
||||||
CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
|
CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
|
||||||
<< "Unfortunately, XGBoost does not support data matrices with "
|
<< "Unfortunately, XGBoost does not support data matrices with "
|
||||||
<< std::numeric_limits<unsigned>::max() << " features or greater";
|
<< std::numeric_limits<unsigned>::max() << " features or greater";
|
||||||
@ -817,13 +804,13 @@ class LearnerConfiguration : public Learner {
|
|||||||
*/
|
*/
|
||||||
void ConfigureTargets() {
|
void ConfigureTargets() {
|
||||||
CHECK(this->obj_);
|
CHECK(this->obj_);
|
||||||
auto const& cache = this->GetPredictionCache()->Container();
|
auto const& cache = prediction_container_.Container();
|
||||||
size_t n_targets = 1;
|
size_t n_targets = 1;
|
||||||
for (auto const& d : cache) {
|
for (auto const& d : cache) {
|
||||||
if (n_targets == 1) {
|
if (n_targets == 1) {
|
||||||
n_targets = this->obj_->Targets(d.first->Info());
|
n_targets = this->obj_->Targets(d.first.ptr->Info());
|
||||||
} else {
|
} else {
|
||||||
auto t = this->obj_->Targets(d.first->Info());
|
auto t = this->obj_->Targets(d.first.ptr->Info());
|
||||||
CHECK(n_targets == t || 1 == t) << "Inconsistent labels.";
|
CHECK(n_targets == t || 1 == t) << "Inconsistent labels.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1275,8 +1262,7 @@ class LearnerImpl : public LearnerIO {
|
|||||||
|
|
||||||
this->ValidateDMatrix(train.get(), true);
|
this->ValidateDMatrix(train.get(), true);
|
||||||
|
|
||||||
auto local_cache = this->GetPredictionCache();
|
auto& predt = prediction_container_.Cache(train, ctx_.gpu_id);
|
||||||
auto& predt = local_cache->Cache(train, ctx_.gpu_id);
|
|
||||||
|
|
||||||
monitor_.Start("PredictRaw");
|
monitor_.Start("PredictRaw");
|
||||||
this->PredictRaw(train.get(), &predt, true, 0, 0);
|
this->PredictRaw(train.get(), &predt, true, 0, 0);
|
||||||
@ -1303,8 +1289,7 @@ class LearnerImpl : public LearnerIO {
|
|||||||
|
|
||||||
this->ValidateDMatrix(train.get(), true);
|
this->ValidateDMatrix(train.get(), true);
|
||||||
|
|
||||||
auto local_cache = this->GetPredictionCache();
|
auto& predt = prediction_container_.Cache(train, ctx_.gpu_id);
|
||||||
auto& predt = local_cache->Cache(train, ctx_.gpu_id);
|
|
||||||
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
|
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
|
||||||
monitor_.Stop("BoostOneIter");
|
monitor_.Stop("BoostOneIter");
|
||||||
}
|
}
|
||||||
@ -1326,10 +1311,9 @@ class LearnerImpl : public LearnerIO {
|
|||||||
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
|
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
|
||||||
}
|
}
|
||||||
|
|
||||||
auto local_cache = this->GetPredictionCache();
|
|
||||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||||
std::shared_ptr<DMatrix> m = data_sets[i];
|
std::shared_ptr<DMatrix> m = data_sets[i];
|
||||||
auto &predt = local_cache->Cache(m, ctx_.gpu_id);
|
auto &predt = prediction_container_.Cache(m, ctx_.gpu_id);
|
||||||
this->ValidateDMatrix(m.get(), false);
|
this->ValidateDMatrix(m.get(), false);
|
||||||
this->PredictRaw(m.get(), &predt, false, 0, 0);
|
this->PredictRaw(m.get(), &predt, false, 0, 0);
|
||||||
|
|
||||||
@ -1370,8 +1354,7 @@ class LearnerImpl : public LearnerIO {
|
|||||||
} else if (pred_leaf) {
|
} else if (pred_leaf) {
|
||||||
gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end);
|
gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end);
|
||||||
} else {
|
} else {
|
||||||
auto local_cache = this->GetPredictionCache();
|
auto& prediction = prediction_container_.Cache(data, ctx_.gpu_id);
|
||||||
auto& prediction = local_cache->Cache(data, ctx_.gpu_id);
|
|
||||||
this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
|
this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
|
||||||
// Copy the prediction cache to output prediction. out_preds comes from C API
|
// Copy the prediction cache to output prediction. out_preds comes from C API
|
||||||
out_preds->SetDevice(ctx_.gpu_id);
|
out_preds->SetDevice(ctx_.gpu_id);
|
||||||
|
|||||||
@ -3,16 +3,18 @@
|
|||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/cache.h>
|
#include <xgboost/cache.h>
|
||||||
#include <xgboost/data.h> // DMatrix
|
#include <xgboost/data.h> // for DMatrix
|
||||||
|
|
||||||
#include <cstddef> // std::size_t
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for uint32_t
|
||||||
|
#include <thread> // for thread
|
||||||
|
|
||||||
#include "helpers.h" // RandomDataGenerator
|
#include "helpers.h" // for RandomDataGenerator
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace {
|
namespace {
|
||||||
struct CacheForTest {
|
struct CacheForTest {
|
||||||
std::size_t i;
|
std::size_t const i;
|
||||||
|
|
||||||
explicit CacheForTest(std::size_t k) : i{k} {}
|
explicit CacheForTest(std::size_t k) : i{k} {}
|
||||||
};
|
};
|
||||||
@ -20,7 +22,7 @@ struct CacheForTest {
|
|||||||
|
|
||||||
TEST(DMatrixCache, Basic) {
|
TEST(DMatrixCache, Basic) {
|
||||||
std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 4;
|
std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 4;
|
||||||
DMatrixCache<CacheForTest> cache(kCacheSize);
|
DMatrixCache<CacheForTest> cache{kCacheSize};
|
||||||
|
|
||||||
auto add_cache = [&]() {
|
auto add_cache = [&]() {
|
||||||
// Create a lambda function here, so that p_fmat gets deleted upon the
|
// Create a lambda function here, so that p_fmat gets deleted upon the
|
||||||
@ -52,4 +54,63 @@ TEST(DMatrixCache, Basic) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(DMatrixCache, MultiThread) {
|
||||||
|
std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 3;
|
||||||
|
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||||
|
|
||||||
|
auto n = std::thread::hardware_concurrency() * 128u;
|
||||||
|
CHECK_NE(n, 0);
|
||||||
|
std::vector<std::shared_ptr<CacheForTest>> results(n);
|
||||||
|
|
||||||
|
{
|
||||||
|
DMatrixCache<CacheForTest> cache{kCacheSize};
|
||||||
|
std::vector<std::thread> tasks;
|
||||||
|
for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
|
||||||
|
tasks.emplace_back([&, i = tidx]() {
|
||||||
|
cache.CacheItem(p_fmat, i);
|
||||||
|
|
||||||
|
auto p_fmat_local = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||||
|
results[i] = cache.CacheItem(p_fmat_local, i);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
for (auto& t : tasks) {
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
|
||||||
|
ASSERT_EQ(results[tidx]->i, tidx);
|
||||||
|
}
|
||||||
|
|
||||||
|
tasks.clear();
|
||||||
|
|
||||||
|
for (std::int32_t tidx = static_cast<std::int32_t>(n - 1); tidx >= 0; --tidx) {
|
||||||
|
tasks.emplace_back([&, i = tidx]() {
|
||||||
|
cache.CacheItem(p_fmat, i);
|
||||||
|
|
||||||
|
auto p_fmat_local = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||||
|
results[i] = cache.CacheItem(p_fmat_local, i);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
for (auto& t : tasks) {
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
|
||||||
|
ASSERT_EQ(results[tidx]->i, tidx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
DMatrixCache<CacheForTest> cache{n};
|
||||||
|
std::vector<std::thread> tasks;
|
||||||
|
for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
|
||||||
|
tasks.emplace_back([&, tidx]() { results[tidx] = cache.CacheItem(p_fmat, tidx); });
|
||||||
|
}
|
||||||
|
for (auto& t : tasks) {
|
||||||
|
t.join();
|
||||||
|
}
|
||||||
|
for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
|
||||||
|
ASSERT_EQ(results[tidx]->i, tidx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user