Generalize prediction cache. (#8783)
* Extract most of the functionality into `DMatrixCache`. * Move API entry to independent file to reduce dependency on `predictor.h` file. * Add test. --------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
ed91e775ec
commit
d11a0044cf
134
include/xgboost/cache.h
Normal file
134
include/xgboost/cache.h
Normal file
@ -0,0 +1,134 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_CACHE_H_
|
||||
#define XGBOOST_CACHE_H_
|
||||
|
||||
#include <xgboost/logging.h> // CHECK_EQ
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
#include <memory> // std::weak_ptr,std::shared_ptr,std::make_shared
|
||||
#include <queue> // std:queue
|
||||
#include <unordered_map> // std::unordered_map
|
||||
#include <vector> // std::vector
|
||||
|
||||
namespace xgboost {
|
||||
class DMatrix;
|
||||
/**
|
||||
* \brief FIFO cache for DMatrix related data.
|
||||
*
|
||||
* \tparam CacheT The type that needs to be cached.
|
||||
*/
|
||||
template <typename CacheT>
|
||||
class DMatrixCache {
|
||||
public:
|
||||
struct Item {
|
||||
// A weak pointer for checking whether the DMatrix object has expired.
|
||||
std::weak_ptr<DMatrix> ref;
|
||||
// The cached item
|
||||
std::shared_ptr<CacheT> value;
|
||||
|
||||
CacheT const& Value() const { return *value; }
|
||||
CacheT& Value() { return *value; }
|
||||
};
|
||||
|
||||
protected:
|
||||
std::unordered_map<DMatrix const*, Item> container_;
|
||||
std::queue<DMatrix const*> queue_;
|
||||
std::size_t max_size_;
|
||||
|
||||
void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
|
||||
|
||||
void ClearExpired() {
|
||||
// Clear expired entries
|
||||
this->CheckConsistent();
|
||||
std::vector<DMatrix const*> expired;
|
||||
std::queue<DMatrix const*> remained;
|
||||
|
||||
while (!queue_.empty()) {
|
||||
auto p_fmat = queue_.front();
|
||||
auto it = container_.find(p_fmat);
|
||||
CHECK(it != container_.cend());
|
||||
if (it->second.ref.expired()) {
|
||||
expired.push_back(it->first);
|
||||
} else {
|
||||
remained.push(it->first);
|
||||
}
|
||||
queue_.pop();
|
||||
}
|
||||
CHECK(queue_.empty());
|
||||
CHECK_EQ(remained.size() + expired.size(), container_.size());
|
||||
|
||||
for (auto const* p_fmat : expired) {
|
||||
container_.erase(p_fmat);
|
||||
}
|
||||
while (!remained.empty()) {
|
||||
auto p_fmat = remained.front();
|
||||
queue_.push(p_fmat);
|
||||
remained.pop();
|
||||
}
|
||||
this->CheckConsistent();
|
||||
}
|
||||
|
||||
void ClearExcess() {
|
||||
this->CheckConsistent();
|
||||
while (queue_.size() >= max_size_) {
|
||||
auto p_fmat = queue_.front();
|
||||
queue_.pop();
|
||||
container_.erase(p_fmat);
|
||||
}
|
||||
this->CheckConsistent();
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* \param cache_size Maximum size of the cache.
|
||||
*/
|
||||
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
|
||||
/**
|
||||
* \brief Cache a new DMatrix if it's no in the cache already.
|
||||
*
|
||||
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
|
||||
* entry this shared pointer is necessary. More importantly, the life time of this
|
||||
* cache is tied to the shared pointer.
|
||||
*
|
||||
* \param m shared pointer to the DMatrix that needs to be cached.
|
||||
* \param args The arguments for constructing a new cache item, if needed.
|
||||
*
|
||||
* \return The cache entry for passed in DMatrix, either an existing cache or newly
|
||||
* created.
|
||||
*/
|
||||
template <typename... Args>
|
||||
std::shared_ptr<CacheT>& CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
|
||||
CHECK(m);
|
||||
this->ClearExpired();
|
||||
if (container_.size() >= max_size_) {
|
||||
this->ClearExcess();
|
||||
}
|
||||
// after clear, cache size < max_size
|
||||
CHECK_LT(container_.size(), max_size_);
|
||||
auto it = container_.find(m.get());
|
||||
if (it == container_.cend()) {
|
||||
// after the new DMatrix, cache size is at most max_size
|
||||
container_[m.get()] = {m, std::make_shared<CacheT>(args...)};
|
||||
queue_.push(m.get());
|
||||
}
|
||||
return container_.at(m.get()).value;
|
||||
}
|
||||
/**
|
||||
* \brief Get a const reference to the underlying hash map. Clear expired caches before
|
||||
* returning.
|
||||
*/
|
||||
decltype(container_) const& Container() {
|
||||
this->ClearExpired();
|
||||
return container_;
|
||||
}
|
||||
|
||||
std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
|
||||
CHECK(container_.find(m) != container_.cend());
|
||||
CHECK(!container_.at(m).ref.expired());
|
||||
return container_.at(m).value;
|
||||
}
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_CACHE_H_
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2014-2023 by XGBoost Contributors
|
||||
* \file gbm.h
|
||||
* \brief Interface of gradient booster,
|
||||
* that learns through gradient statistics.
|
||||
@ -31,7 +31,6 @@ class ObjFunction;
|
||||
struct Context;
|
||||
struct LearnerModelParam;
|
||||
struct PredictionCacheEntry;
|
||||
class PredictionContainer;
|
||||
|
||||
/*!
|
||||
* \brief interface of gradient boosting model.
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file learner.h
|
||||
* \brief Learner interface that integrates objective, gbm and evaluation together.
|
||||
* This is the user facing XGBoost training module.
|
||||
@ -8,12 +8,13 @@
|
||||
#ifndef XGBOOST_LEARNER_H_
|
||||
#define XGBOOST_LEARNER_H_
|
||||
|
||||
#include <dmlc/io.h> // Serializable
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/context.h> // Context
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
#include <xgboost/linalg.h> // Tensor
|
||||
#include <xgboost/model.h>
|
||||
#include <xgboost/predictor.h>
|
||||
#include <xgboost/task.h>
|
||||
|
||||
#include <map>
|
||||
@ -29,6 +30,7 @@ class GradientBooster;
|
||||
class ObjFunction;
|
||||
class DMatrix;
|
||||
class Json;
|
||||
struct XGBAPIThreadLocalEntry;
|
||||
|
||||
enum class PredictionType : std::uint8_t { // NOLINT
|
||||
kValue = 0,
|
||||
@ -40,26 +42,6 @@ enum class PredictionType : std::uint8_t { // NOLINT
|
||||
kLeaf = 6
|
||||
};
|
||||
|
||||
/*! \brief entry to to easily hold returning information */
|
||||
struct XGBAPIThreadLocalEntry {
|
||||
/*! \brief result holder for returning string */
|
||||
std::string ret_str;
|
||||
/*! \brief result holder for returning raw buffer */
|
||||
std::vector<char> ret_char_vec;
|
||||
/*! \brief result holder for returning strings */
|
||||
std::vector<std::string> ret_vec_str;
|
||||
/*! \brief result holder for returning string pointers */
|
||||
std::vector<const char *> ret_vec_charp;
|
||||
/*! \brief returning float vector. */
|
||||
std::vector<bst_float> ret_vec_float;
|
||||
/*! \brief temp variable of gradient pairs. */
|
||||
std::vector<GradientPair> tmp_gpair;
|
||||
/*! \brief Temp variable for returning prediction result. */
|
||||
PredictionCacheEntry prediction_entry;
|
||||
/*! \brief Temp variable for returning prediction shape. */
|
||||
std::vector<bst_ulong> prediction_shape;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Learner class that does training and prediction.
|
||||
* This is the user facing module of xgboost training.
|
||||
|
||||
@ -1,95 +1,66 @@
|
||||
/*!
|
||||
* Copyright 2017-2022 by Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by Contributors
|
||||
* \file predictor.h
|
||||
* \brief Interface of predictor,
|
||||
* performs predictions for a gradient booster.
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/cache.h> // DMatrixCache
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/host_device_vector.h>
|
||||
|
||||
#include <functional>
|
||||
#include <functional> // std::function
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// Forward declarations
|
||||
namespace xgboost {
|
||||
class TreeUpdater;
|
||||
namespace gbm {
|
||||
struct GBTreeModel;
|
||||
} // namespace gbm
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
/**
|
||||
* \struct PredictionCacheEntry
|
||||
*
|
||||
* \brief Contains pointer to input matrix and associated cached predictions.
|
||||
*/
|
||||
struct PredictionCacheEntry {
|
||||
// A storage for caching prediction values
|
||||
HostDeviceVector<bst_float> predictions;
|
||||
// The version of current cache, corresponding number of layers of trees
|
||||
uint32_t version { 0 };
|
||||
// A weak pointer for checking whether the DMatrix object has expired.
|
||||
std::weak_ptr< DMatrix > ref;
|
||||
std::uint32_t version{0};
|
||||
|
||||
PredictionCacheEntry() = default;
|
||||
/* \brief Update the cache entry by number of versions.
|
||||
/**
|
||||
* \brief Update the cache entry by number of versions.
|
||||
*
|
||||
* \param v Added versions.
|
||||
*/
|
||||
void Update(uint32_t v) {
|
||||
void Update(std::uint32_t v) {
|
||||
version += v;
|
||||
}
|
||||
};
|
||||
|
||||
/* \brief A container for managed prediction caches.
|
||||
/**
|
||||
* \brief A container for managed prediction caches.
|
||||
*/
|
||||
class PredictionContainer {
|
||||
std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
|
||||
void ClearExpiredEntries();
|
||||
class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
|
||||
// we cache up to 32 DMatrix
|
||||
std::size_t static constexpr DefaultSize() { return 32; }
|
||||
|
||||
public:
|
||||
PredictionContainer() = default;
|
||||
/* \brief Add a new DMatrix to the cache, at the same time this function will clear out
|
||||
* all expired caches by checking the `std::weak_ptr`. Caching an existing
|
||||
* DMatrix won't renew it.
|
||||
*
|
||||
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
|
||||
* entry this shared pointer is necessary. More importantly, the life time of this
|
||||
* cache is tied to the shared pointer.
|
||||
*
|
||||
* Another way to make a safe cache is create a proxy to this entry, with anther shared
|
||||
* pointer defined inside, and pass this proxy around instead of the real entry. But
|
||||
* seems to be too messy. In XGBoost, functions like `UpdateOneIter` will have
|
||||
* (memory) safe access to the DMatrix as long as it's passed in as a `shared_ptr`.
|
||||
*
|
||||
* \param m shared pointer to the DMatrix that needs to be cached.
|
||||
* \param device Which device should the cache be allocated on. Pass
|
||||
* Context::kCpuId for CPU or positive integer for GPU id.
|
||||
*
|
||||
* \return the cache entry for passed in DMatrix, either an existing cache or newly
|
||||
* created.
|
||||
*/
|
||||
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device);
|
||||
/* \brief Get a prediction cache entry. This entry must be already allocated by `Cache`
|
||||
* method. Otherwise a dmlc::Error is thrown.
|
||||
*
|
||||
* \param m pointer to the DMatrix.
|
||||
* \return The prediction cache for passed in DMatrix.
|
||||
*/
|
||||
PredictionCacheEntry& Entry(DMatrix* m);
|
||||
/* \brief Get a const reference to the underlying hash map. Clear expired caches before
|
||||
* returning.
|
||||
*/
|
||||
decltype(container_) const& Container();
|
||||
PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
|
||||
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
||||
this->CacheItem(m);
|
||||
auto p_cache = this->container_.find(m.get());
|
||||
if (device != Context::kCpuId) {
|
||||
p_cache->second.Value().predictions.SetDevice(device);
|
||||
}
|
||||
return p_cache->second.Value();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
@ -114,7 +85,7 @@ class Predictor {
|
||||
*
|
||||
* \param cfg The configuration.
|
||||
*/
|
||||
virtual void Configure(const std::vector<std::pair<std::string, std::string>>&);
|
||||
virtual void Configure(Args const&);
|
||||
|
||||
/**
|
||||
* \brief Initialize output prediction
|
||||
|
||||
@ -12,11 +12,11 @@
|
||||
#include <vector>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../common/charconv.h"
|
||||
#include "../common/io.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/simple_dmatrix.h"
|
||||
#include "c_api_error.h"
|
||||
#include "c_api_utils.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../data/device_adapter.cuh"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
|
||||
35
src/common/api_entry.h
Normal file
35
src/common/api_entry.h
Normal file
@ -0,0 +1,35 @@
|
||||
/**
|
||||
* Copyright 2016-2023 by XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_API_ENTRY_H_
|
||||
#define XGBOOST_COMMON_API_ENTRY_H_
|
||||
#include <string> // std::string
|
||||
#include <vector> // std::vector
|
||||
|
||||
#include "xgboost/base.h" // GradientPair,bst_ulong
|
||||
#include "xgboost/predictor.h" // PredictionCacheEntry
|
||||
|
||||
namespace xgboost {
|
||||
/**
|
||||
* \brief entry to to easily hold returning information
|
||||
*/
|
||||
struct XGBAPIThreadLocalEntry {
|
||||
/*! \brief result holder for returning string */
|
||||
std::string ret_str;
|
||||
/*! \brief result holder for returning raw buffer */
|
||||
std::vector<char> ret_char_vec;
|
||||
/*! \brief result holder for returning strings */
|
||||
std::vector<std::string> ret_vec_str;
|
||||
/*! \brief result holder for returning string pointers */
|
||||
std::vector<const char *> ret_vec_charp;
|
||||
/*! \brief returning float vector. */
|
||||
std::vector<float> ret_vec_float;
|
||||
/*! \brief temp variable of gradient pairs. */
|
||||
std::vector<GradientPair> tmp_gpair;
|
||||
/*! \brief Temp variable for returning prediction result. */
|
||||
PredictionCacheEntry prediction_entry;
|
||||
/*! \brief Temp variable for returning prediction shape. */
|
||||
std::vector<bst_ulong> prediction_shape;
|
||||
};
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_API_ENTRY_H_
|
||||
@ -10,6 +10,7 @@
|
||||
#include <cstring>
|
||||
|
||||
#include "../collective/communicator-inl.h"
|
||||
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../common/group_data.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/linalg_op.h"
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "collective/communicator-inl.h"
|
||||
#include "common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "common/charconv.h"
|
||||
#include "common/common.h"
|
||||
#include "common/io.h"
|
||||
@ -430,7 +431,9 @@ class LearnerConfiguration : public Learner {
|
||||
monitor_.Init("Learner");
|
||||
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
|
||||
for (std::shared_ptr<DMatrix> const& d : cache) {
|
||||
local_cache.Cache(d, Context::kCpuId);
|
||||
if (d) {
|
||||
local_cache.Cache(d, Context::kCpuId);
|
||||
}
|
||||
}
|
||||
}
|
||||
~LearnerConfiguration() override {
|
||||
@ -1296,9 +1299,8 @@ class LearnerImpl : public LearnerIO {
|
||||
this->ValidateDMatrix(train.get(), true);
|
||||
|
||||
auto local_cache = this->GetPredictionCache();
|
||||
local_cache->Cache(train, ctx_.gpu_id);
|
||||
|
||||
gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get());
|
||||
auto& predt = local_cache->Cache(train, ctx_.gpu_id);
|
||||
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
|
||||
monitor_.Stop("BoostOneIter");
|
||||
}
|
||||
|
||||
|
||||
@ -1,57 +1,28 @@
|
||||
/*!
|
||||
* Copyright 2017-2021 by Contributors
|
||||
/**
|
||||
* Copyright 2017-2023 by Contributors
|
||||
*/
|
||||
#include "xgboost/predictor.h"
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
#include <mutex>
|
||||
#include <string> // std::string
|
||||
|
||||
#include "../gbm/gbtree.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/data.h"
|
||||
#include "../gbm/gbtree.h" // GBTreeModel
|
||||
#include "xgboost/base.h" // bst_row_t,bst_group_t
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/data.h" // MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/learner.h" // LearnerModelParam
|
||||
#include "xgboost/linalg.h" // Tensor
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
void PredictionContainer::ClearExpiredEntries() {
|
||||
std::vector<DMatrix*> expired;
|
||||
for (auto& kv : container_) {
|
||||
if (kv.second.ref.expired()) {
|
||||
expired.emplace_back(kv.first);
|
||||
}
|
||||
}
|
||||
for (auto const& ptr : expired) {
|
||||
container_.erase(ptr);
|
||||
}
|
||||
}
|
||||
void Predictor::Configure(Args const&) {}
|
||||
|
||||
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
||||
this->ClearExpiredEntries();
|
||||
container_[m.get()].ref = m;
|
||||
if (device != Context::kCpuId) {
|
||||
container_[m.get()].predictions.SetDevice(device);
|
||||
}
|
||||
return container_[m.get()];
|
||||
}
|
||||
|
||||
PredictionCacheEntry &PredictionContainer::Entry(DMatrix *m) {
|
||||
CHECK(container_.find(m) != container_.cend());
|
||||
CHECK(container_.at(m).ref.lock())
|
||||
<< "[Internal error]: DMatrix: " << m << " has expired.";
|
||||
return container_.at(m);
|
||||
}
|
||||
|
||||
decltype(PredictionContainer::container_) const& PredictionContainer::Container() {
|
||||
this->ClearExpiredEntries();
|
||||
return container_;
|
||||
}
|
||||
|
||||
void Predictor::Configure(
|
||||
const std::vector<std::pair<std::string, std::string>>&) {
|
||||
}
|
||||
Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
|
||||
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
|
||||
55
tests/cpp/test_cache.cc
Normal file
55
tests/cpp/test_cache.cc
Normal file
@ -0,0 +1,55 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/cache.h>
|
||||
#include <xgboost/data.h> // DMatrix
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
|
||||
#include "helpers.h" // RandomDataGenerator
|
||||
|
||||
namespace xgboost {
|
||||
namespace {
|
||||
struct CacheForTest {
|
||||
std::size_t i;
|
||||
|
||||
explicit CacheForTest(std::size_t k) : i{k} {}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(DMatrixCache, Basic) {
|
||||
std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 4;
|
||||
DMatrixCache<CacheForTest> cache(kCacheSize);
|
||||
|
||||
auto add_cache = [&]() {
|
||||
// Create a lambda function here, so that p_fmat gets deleted upon the
|
||||
// end of the lambda. This is to test how the cache handle expired
|
||||
// cache entries.
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
cache.CacheItem(p_fmat, 3);
|
||||
DMatrix* m = p_fmat.get();
|
||||
return m;
|
||||
};
|
||||
auto m = add_cache();
|
||||
ASSERT_EQ(cache.Container().size(), 0);
|
||||
ASSERT_THROW(cache.Entry(m), dmlc::Error);
|
||||
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
auto item = cache.CacheItem(p_fmat, 1);
|
||||
ASSERT_EQ(cache.Entry(p_fmat.get())->i, 1);
|
||||
|
||||
std::vector<std::shared_ptr<DMatrix>> items;
|
||||
for (std::size_t i = 0; i < kCacheSize * 2; ++i) {
|
||||
items.emplace_back(RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix());
|
||||
cache.CacheItem(items.back(), i);
|
||||
ASSERT_EQ(cache.Entry(items.back().get())->i, i);
|
||||
ASSERT_LE(cache.Container().size(), kCacheSize);
|
||||
if (i > kCacheSize) {
|
||||
auto k = i - kCacheSize - 1;
|
||||
ASSERT_THROW(cache.Entry(items[k].get()), dmlc::Error);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2022 by XGBoost contributors
|
||||
* Copyright 2017-2023 by XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/learner.h>
|
||||
@ -10,6 +10,7 @@
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
#include "../../src/common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||
#include "../../src/common/io.h"
|
||||
#include "../../src/common/linalg_op.h"
|
||||
#include "../../src/common/random.h"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user