Generalize prediction cache. (#8783)
* Extract most of the functionality into `DMatrixCache`. * Move API entry to independent file to reduce dependency on `predictor.h` file. * Add test. --------- Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
ed91e775ec
commit
d11a0044cf
134
include/xgboost/cache.h
Normal file
134
include/xgboost/cache.h
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost contributors
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_CACHE_H_
|
||||||
|
#define XGBOOST_CACHE_H_
|
||||||
|
|
||||||
|
#include <xgboost/logging.h> // CHECK_EQ
|
||||||
|
|
||||||
|
#include <cstddef> // std::size_t
|
||||||
|
#include <memory> // std::weak_ptr,std::shared_ptr,std::make_shared
|
||||||
|
#include <queue> // std::queue
|
||||||
|
#include <unordered_map> // std::unordered_map
|
||||||
|
#include <vector> // std::vector
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
class DMatrix;
|
||||||
|
/**
|
||||||
|
* \brief FIFO cache for DMatrix related data.
|
||||||
|
*
|
||||||
|
* \tparam CacheT The type that needs to be cached.
|
||||||
|
*/
|
||||||
|
template <typename CacheT>
|
||||||
|
class DMatrixCache {
|
||||||
|
public:
|
||||||
|
struct Item {
|
||||||
|
// A weak pointer for checking whether the DMatrix object has expired.
|
||||||
|
std::weak_ptr<DMatrix> ref;
|
||||||
|
// The cached item
|
||||||
|
std::shared_ptr<CacheT> value;
|
||||||
|
|
||||||
|
CacheT const& Value() const { return *value; }
|
||||||
|
CacheT& Value() { return *value; }
|
||||||
|
};
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::unordered_map<DMatrix const*, Item> container_;
|
||||||
|
std::queue<DMatrix const*> queue_;
|
||||||
|
std::size_t max_size_;
|
||||||
|
|
||||||
|
void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
|
||||||
|
|
||||||
|
void ClearExpired() {
|
||||||
|
// Clear expired entries
|
||||||
|
this->CheckConsistent();
|
||||||
|
std::vector<DMatrix const*> expired;
|
||||||
|
std::queue<DMatrix const*> remained;
|
||||||
|
|
||||||
|
while (!queue_.empty()) {
|
||||||
|
auto p_fmat = queue_.front();
|
||||||
|
auto it = container_.find(p_fmat);
|
||||||
|
CHECK(it != container_.cend());
|
||||||
|
if (it->second.ref.expired()) {
|
||||||
|
expired.push_back(it->first);
|
||||||
|
} else {
|
||||||
|
remained.push(it->first);
|
||||||
|
}
|
||||||
|
queue_.pop();
|
||||||
|
}
|
||||||
|
CHECK(queue_.empty());
|
||||||
|
CHECK_EQ(remained.size() + expired.size(), container_.size());
|
||||||
|
|
||||||
|
for (auto const* p_fmat : expired) {
|
||||||
|
container_.erase(p_fmat);
|
||||||
|
}
|
||||||
|
while (!remained.empty()) {
|
||||||
|
auto p_fmat = remained.front();
|
||||||
|
queue_.push(p_fmat);
|
||||||
|
remained.pop();
|
||||||
|
}
|
||||||
|
this->CheckConsistent();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ClearExcess() {
|
||||||
|
this->CheckConsistent();
|
||||||
|
while (queue_.size() >= max_size_) {
|
||||||
|
auto p_fmat = queue_.front();
|
||||||
|
queue_.pop();
|
||||||
|
container_.erase(p_fmat);
|
||||||
|
}
|
||||||
|
this->CheckConsistent();
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* \param cache_size Maximum size of the cache.
|
||||||
|
*/
|
||||||
|
explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
|
||||||
|
/**
|
||||||
|
* \brief Cache a new DMatrix if it's no in the cache already.
|
||||||
|
*
|
||||||
|
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
|
||||||
|
* entry this shared pointer is necessary. More importantly, the life time of this
|
||||||
|
* cache is tied to the shared pointer.
|
||||||
|
*
|
||||||
|
* \param m shared pointer to the DMatrix that needs to be cached.
|
||||||
|
* \param args The arguments for constructing a new cache item, if needed.
|
||||||
|
*
|
||||||
|
* \return The cache entry for passed in DMatrix, either an existing cache or newly
|
||||||
|
* created.
|
||||||
|
*/
|
||||||
|
template <typename... Args>
|
||||||
|
std::shared_ptr<CacheT>& CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
|
||||||
|
CHECK(m);
|
||||||
|
this->ClearExpired();
|
||||||
|
if (container_.size() >= max_size_) {
|
||||||
|
this->ClearExcess();
|
||||||
|
}
|
||||||
|
// after clear, cache size < max_size
|
||||||
|
CHECK_LT(container_.size(), max_size_);
|
||||||
|
auto it = container_.find(m.get());
|
||||||
|
if (it == container_.cend()) {
|
||||||
|
// after the new DMatrix, cache size is at most max_size
|
||||||
|
container_[m.get()] = {m, std::make_shared<CacheT>(args...)};
|
||||||
|
queue_.push(m.get());
|
||||||
|
}
|
||||||
|
return container_.at(m.get()).value;
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* \brief Get a const reference to the underlying hash map. Clear expired caches before
|
||||||
|
* returning.
|
||||||
|
*/
|
||||||
|
decltype(container_) const& Container() {
|
||||||
|
this->ClearExpired();
|
||||||
|
return container_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
|
||||||
|
CHECK(container_.find(m) != container_.cend());
|
||||||
|
CHECK(!container_.at(m).ref.expired());
|
||||||
|
return container_.at(m).value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_CACHE_H_
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2014-2022 by XGBoost Contributors
|
* Copyright 2014-2023 by XGBoost Contributors
|
||||||
* \file gbm.h
|
* \file gbm.h
|
||||||
* \brief Interface of gradient booster,
|
* \brief Interface of gradient booster,
|
||||||
* that learns through gradient statistics.
|
* that learns through gradient statistics.
|
||||||
@ -31,7 +31,6 @@ class ObjFunction;
|
|||||||
struct Context;
|
struct Context;
|
||||||
struct LearnerModelParam;
|
struct LearnerModelParam;
|
||||||
struct PredictionCacheEntry;
|
struct PredictionCacheEntry;
|
||||||
class PredictionContainer;
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief interface of gradient boosting model.
|
* \brief interface of gradient boosting model.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2015-2022 by XGBoost Contributors
|
* Copyright 2015-2023 by XGBoost Contributors
|
||||||
* \file learner.h
|
* \file learner.h
|
||||||
* \brief Learner interface that integrates objective, gbm and evaluation together.
|
* \brief Learner interface that integrates objective, gbm and evaluation together.
|
||||||
* This is the user facing XGBoost training module.
|
* This is the user facing XGBoost training module.
|
||||||
@ -8,12 +8,13 @@
|
|||||||
#ifndef XGBOOST_LEARNER_H_
|
#ifndef XGBOOST_LEARNER_H_
|
||||||
#define XGBOOST_LEARNER_H_
|
#define XGBOOST_LEARNER_H_
|
||||||
|
|
||||||
|
#include <dmlc/io.h> // Serializable
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
#include <xgboost/context.h> // Context
|
#include <xgboost/context.h> // Context
|
||||||
#include <xgboost/feature_map.h>
|
#include <xgboost/feature_map.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
|
#include <xgboost/linalg.h> // Tensor
|
||||||
#include <xgboost/model.h>
|
#include <xgboost/model.h>
|
||||||
#include <xgboost/predictor.h>
|
|
||||||
#include <xgboost/task.h>
|
#include <xgboost/task.h>
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
@ -29,6 +30,7 @@ class GradientBooster;
|
|||||||
class ObjFunction;
|
class ObjFunction;
|
||||||
class DMatrix;
|
class DMatrix;
|
||||||
class Json;
|
class Json;
|
||||||
|
struct XGBAPIThreadLocalEntry;
|
||||||
|
|
||||||
enum class PredictionType : std::uint8_t { // NOLINT
|
enum class PredictionType : std::uint8_t { // NOLINT
|
||||||
kValue = 0,
|
kValue = 0,
|
||||||
@ -40,26 +42,6 @@ enum class PredictionType : std::uint8_t { // NOLINT
|
|||||||
kLeaf = 6
|
kLeaf = 6
|
||||||
};
|
};
|
||||||
|
|
||||||
/*! \brief entry to to easily hold returning information */
|
|
||||||
struct XGBAPIThreadLocalEntry {
|
|
||||||
/*! \brief result holder for returning string */
|
|
||||||
std::string ret_str;
|
|
||||||
/*! \brief result holder for returning raw buffer */
|
|
||||||
std::vector<char> ret_char_vec;
|
|
||||||
/*! \brief result holder for returning strings */
|
|
||||||
std::vector<std::string> ret_vec_str;
|
|
||||||
/*! \brief result holder for returning string pointers */
|
|
||||||
std::vector<const char *> ret_vec_charp;
|
|
||||||
/*! \brief returning float vector. */
|
|
||||||
std::vector<bst_float> ret_vec_float;
|
|
||||||
/*! \brief temp variable of gradient pairs. */
|
|
||||||
std::vector<GradientPair> tmp_gpair;
|
|
||||||
/*! \brief Temp variable for returning prediction result. */
|
|
||||||
PredictionCacheEntry prediction_entry;
|
|
||||||
/*! \brief Temp variable for returning prediction shape. */
|
|
||||||
std::vector<bst_ulong> prediction_shape;
|
|
||||||
};
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Learner class that does training and prediction.
|
* \brief Learner class that does training and prediction.
|
||||||
* This is the user facing module of xgboost training.
|
* This is the user facing module of xgboost training.
|
||||||
|
|||||||
@ -1,95 +1,66 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 by Contributors
|
* Copyright 2017-2023 by Contributors
|
||||||
* \file predictor.h
|
* \file predictor.h
|
||||||
* \brief Interface of predictor,
|
* \brief Interface of predictor,
|
||||||
* performs predictions for a gradient booster.
|
* performs predictions for a gradient booster.
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
#include <xgboost/base.h>
|
#include <xgboost/base.h>
|
||||||
|
#include <xgboost/cache.h> // DMatrixCache
|
||||||
#include <xgboost/context.h>
|
#include <xgboost/context.h>
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
#include <xgboost/host_device_vector.h>
|
#include <xgboost/host_device_vector.h>
|
||||||
|
|
||||||
#include <functional>
|
#include <functional> // std::function
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mutex>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// Forward declarations
|
// Forward declarations
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
class TreeUpdater;
|
|
||||||
namespace gbm {
|
namespace gbm {
|
||||||
struct GBTreeModel;
|
struct GBTreeModel;
|
||||||
} // namespace gbm
|
} // namespace gbm
|
||||||
}
|
} // namespace xgboost
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
/**
|
/**
|
||||||
* \struct PredictionCacheEntry
|
|
||||||
*
|
|
||||||
* \brief Contains pointer to input matrix and associated cached predictions.
|
* \brief Contains pointer to input matrix and associated cached predictions.
|
||||||
*/
|
*/
|
||||||
struct PredictionCacheEntry {
|
struct PredictionCacheEntry {
|
||||||
// A storage for caching prediction values
|
// A storage for caching prediction values
|
||||||
HostDeviceVector<bst_float> predictions;
|
HostDeviceVector<bst_float> predictions;
|
||||||
// The version of current cache, corresponding number of layers of trees
|
// The version of current cache, corresponding number of layers of trees
|
||||||
uint32_t version { 0 };
|
std::uint32_t version{0};
|
||||||
// A weak pointer for checking whether the DMatrix object has expired.
|
|
||||||
std::weak_ptr< DMatrix > ref;
|
|
||||||
|
|
||||||
PredictionCacheEntry() = default;
|
PredictionCacheEntry() = default;
|
||||||
/* \brief Update the cache entry by number of versions.
|
/**
|
||||||
|
* \brief Update the cache entry by number of versions.
|
||||||
*
|
*
|
||||||
* \param v Added versions.
|
* \param v Added versions.
|
||||||
*/
|
*/
|
||||||
void Update(uint32_t v) {
|
void Update(std::uint32_t v) {
|
||||||
version += v;
|
version += v;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* \brief A container for managed prediction caches.
|
/**
|
||||||
|
* \brief A container for managed prediction caches.
|
||||||
*/
|
*/
|
||||||
class PredictionContainer {
|
class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
|
||||||
std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
|
// we cache up to 32 DMatrix
|
||||||
void ClearExpiredEntries();
|
std::size_t static constexpr DefaultSize() { return 32; }
|
||||||
|
|
||||||
public:
|
public:
|
||||||
PredictionContainer() = default;
|
PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
|
||||||
/* \brief Add a new DMatrix to the cache, at the same time this function will clear out
|
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
||||||
* all expired caches by checking the `std::weak_ptr`. Caching an existing
|
this->CacheItem(m);
|
||||||
* DMatrix won't renew it.
|
auto p_cache = this->container_.find(m.get());
|
||||||
*
|
if (device != Context::kCpuId) {
|
||||||
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
|
p_cache->second.Value().predictions.SetDevice(device);
|
||||||
* entry this shared pointer is necessary. More importantly, the life time of this
|
}
|
||||||
* cache is tied to the shared pointer.
|
return p_cache->second.Value();
|
||||||
*
|
}
|
||||||
* Another way to make a safe cache is create a proxy to this entry, with anther shared
|
|
||||||
* pointer defined inside, and pass this proxy around instead of the real entry. But
|
|
||||||
* seems to be too messy. In XGBoost, functions like `UpdateOneIter` will have
|
|
||||||
* (memory) safe access to the DMatrix as long as it's passed in as a `shared_ptr`.
|
|
||||||
*
|
|
||||||
* \param m shared pointer to the DMatrix that needs to be cached.
|
|
||||||
* \param device Which device should the cache be allocated on. Pass
|
|
||||||
* Context::kCpuId for CPU or positive integer for GPU id.
|
|
||||||
*
|
|
||||||
* \return the cache entry for passed in DMatrix, either an existing cache or newly
|
|
||||||
* created.
|
|
||||||
*/
|
|
||||||
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device);
|
|
||||||
/* \brief Get a prediction cache entry. This entry must be already allocated by `Cache`
|
|
||||||
* method. Otherwise a dmlc::Error is thrown.
|
|
||||||
*
|
|
||||||
* \param m pointer to the DMatrix.
|
|
||||||
* \return The prediction cache for passed in DMatrix.
|
|
||||||
*/
|
|
||||||
PredictionCacheEntry& Entry(DMatrix* m);
|
|
||||||
/* \brief Get a const reference to the underlying hash map. Clear expired caches before
|
|
||||||
* returning.
|
|
||||||
*/
|
|
||||||
decltype(container_) const& Container();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -114,7 +85,7 @@ class Predictor {
|
|||||||
*
|
*
|
||||||
* \param cfg The configuration.
|
* \param cfg The configuration.
|
||||||
*/
|
*/
|
||||||
virtual void Configure(const std::vector<std::pair<std::string, std::string>>&);
|
virtual void Configure(Args const&);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Initialize output prediction
|
* \brief Initialize output prediction
|
||||||
|
|||||||
@ -12,11 +12,11 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "../collective/communicator-inl.h"
|
#include "../collective/communicator-inl.h"
|
||||||
|
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||||
#include "../common/charconv.h"
|
#include "../common/charconv.h"
|
||||||
#include "../common/io.h"
|
#include "../common/io.h"
|
||||||
#include "../data/adapter.h"
|
#include "../data/adapter.h"
|
||||||
#include "../data/simple_dmatrix.h"
|
#include "../data/simple_dmatrix.h"
|
||||||
#include "c_api_error.h"
|
|
||||||
#include "c_api_utils.h"
|
#include "c_api_utils.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2019-2023 by XGBoost Contributors
|
* Copyright 2019-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
|
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
#include "../data/device_adapter.cuh"
|
#include "../data/device_adapter.cuh"
|
||||||
#include "../data/proxy_dmatrix.h"
|
#include "../data/proxy_dmatrix.h"
|
||||||
|
|||||||
35
src/common/api_entry.h
Normal file
35
src/common/api_entry.h
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2016-2023 by XGBoost contributors
|
||||||
|
*/
|
||||||
|
#ifndef XGBOOST_COMMON_API_ENTRY_H_
|
||||||
|
#define XGBOOST_COMMON_API_ENTRY_H_
|
||||||
|
#include <string> // std::string
|
||||||
|
#include <vector> // std::vector
|
||||||
|
|
||||||
|
#include "xgboost/base.h" // GradientPair,bst_ulong
|
||||||
|
#include "xgboost/predictor.h" // PredictionCacheEntry
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
/**
|
||||||
|
* \brief entry to to easily hold returning information
|
||||||
|
*/
|
||||||
|
struct XGBAPIThreadLocalEntry {
|
||||||
|
/*! \brief result holder for returning string */
|
||||||
|
std::string ret_str;
|
||||||
|
/*! \brief result holder for returning raw buffer */
|
||||||
|
std::vector<char> ret_char_vec;
|
||||||
|
/*! \brief result holder for returning strings */
|
||||||
|
std::vector<std::string> ret_vec_str;
|
||||||
|
/*! \brief result holder for returning string pointers */
|
||||||
|
std::vector<const char *> ret_vec_charp;
|
||||||
|
/*! \brief returning float vector. */
|
||||||
|
std::vector<float> ret_vec_float;
|
||||||
|
/*! \brief temp variable of gradient pairs. */
|
||||||
|
std::vector<GradientPair> tmp_gpair;
|
||||||
|
/*! \brief Temp variable for returning prediction result. */
|
||||||
|
PredictionCacheEntry prediction_entry;
|
||||||
|
/*! \brief Temp variable for returning prediction shape. */
|
||||||
|
std::vector<bst_ulong> prediction_shape;
|
||||||
|
};
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif // XGBOOST_COMMON_API_ENTRY_H_
|
||||||
@ -10,6 +10,7 @@
|
|||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|
||||||
#include "../collective/communicator-inl.h"
|
#include "../collective/communicator-inl.h"
|
||||||
|
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||||
#include "../common/group_data.h"
|
#include "../common/group_data.h"
|
||||||
#include "../common/io.h"
|
#include "../common/io.h"
|
||||||
#include "../common/linalg_op.h"
|
#include "../common/linalg_op.h"
|
||||||
|
|||||||
@ -25,6 +25,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "collective/communicator-inl.h"
|
#include "collective/communicator-inl.h"
|
||||||
|
#include "common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||||
#include "common/charconv.h"
|
#include "common/charconv.h"
|
||||||
#include "common/common.h"
|
#include "common/common.h"
|
||||||
#include "common/io.h"
|
#include "common/io.h"
|
||||||
@ -430,7 +431,9 @@ class LearnerConfiguration : public Learner {
|
|||||||
monitor_.Init("Learner");
|
monitor_.Init("Learner");
|
||||||
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
|
auto& local_cache = (*ThreadLocalPredictionCache::Get())[this];
|
||||||
for (std::shared_ptr<DMatrix> const& d : cache) {
|
for (std::shared_ptr<DMatrix> const& d : cache) {
|
||||||
local_cache.Cache(d, Context::kCpuId);
|
if (d) {
|
||||||
|
local_cache.Cache(d, Context::kCpuId);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
~LearnerConfiguration() override {
|
~LearnerConfiguration() override {
|
||||||
@ -1296,9 +1299,8 @@ class LearnerImpl : public LearnerIO {
|
|||||||
this->ValidateDMatrix(train.get(), true);
|
this->ValidateDMatrix(train.get(), true);
|
||||||
|
|
||||||
auto local_cache = this->GetPredictionCache();
|
auto local_cache = this->GetPredictionCache();
|
||||||
local_cache->Cache(train, ctx_.gpu_id);
|
auto& predt = local_cache->Cache(train, ctx_.gpu_id);
|
||||||
|
gbm_->DoBoost(train.get(), in_gpair, &predt, obj_.get());
|
||||||
gbm_->DoBoost(train.get(), in_gpair, &local_cache->Entry(train.get()), obj_.get());
|
|
||||||
monitor_.Stop("BoostOneIter");
|
monitor_.Stop("BoostOneIter");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,57 +1,28 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2021 by Contributors
|
* Copyright 2017-2023 by Contributors
|
||||||
*/
|
*/
|
||||||
#include "xgboost/predictor.h"
|
#include "xgboost/predictor.h"
|
||||||
|
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
|
|
||||||
#include <mutex>
|
#include <string> // std::string
|
||||||
|
|
||||||
#include "../gbm/gbtree.h"
|
#include "../gbm/gbtree.h" // GBTreeModel
|
||||||
#include "xgboost/context.h"
|
#include "xgboost/base.h" // bst_row_t,bst_group_t
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/context.h" // Context
|
||||||
|
#include "xgboost/data.h" // MetaInfo
|
||||||
|
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||||
|
#include "xgboost/learner.h" // LearnerModelParam
|
||||||
|
#include "xgboost/linalg.h" // Tensor
|
||||||
|
#include "xgboost/logging.h"
|
||||||
|
|
||||||
namespace dmlc {
|
namespace dmlc {
|
||||||
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
|
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
|
||||||
} // namespace dmlc
|
} // namespace dmlc
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
void PredictionContainer::ClearExpiredEntries() {
|
void Predictor::Configure(Args const&) {}
|
||||||
std::vector<DMatrix*> expired;
|
|
||||||
for (auto& kv : container_) {
|
|
||||||
if (kv.second.ref.expired()) {
|
|
||||||
expired.emplace_back(kv.first);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (auto const& ptr : expired) {
|
|
||||||
container_.erase(ptr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
|
|
||||||
this->ClearExpiredEntries();
|
|
||||||
container_[m.get()].ref = m;
|
|
||||||
if (device != Context::kCpuId) {
|
|
||||||
container_[m.get()].predictions.SetDevice(device);
|
|
||||||
}
|
|
||||||
return container_[m.get()];
|
|
||||||
}
|
|
||||||
|
|
||||||
PredictionCacheEntry &PredictionContainer::Entry(DMatrix *m) {
|
|
||||||
CHECK(container_.find(m) != container_.cend());
|
|
||||||
CHECK(container_.at(m).ref.lock())
|
|
||||||
<< "[Internal error]: DMatrix: " << m << " has expired.";
|
|
||||||
return container_.at(m);
|
|
||||||
}
|
|
||||||
|
|
||||||
decltype(PredictionContainer::container_) const& PredictionContainer::Container() {
|
|
||||||
this->ClearExpiredEntries();
|
|
||||||
return container_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Predictor::Configure(
|
|
||||||
const std::vector<std::pair<std::string, std::string>>&) {
|
|
||||||
}
|
|
||||||
Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
|
Predictor* Predictor::Create(std::string const& name, Context const* ctx) {
|
||||||
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
|
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
|
||||||
if (e == nullptr) {
|
if (e == nullptr) {
|
||||||
|
|||||||
55
tests/cpp/test_cache.cc
Normal file
55
tests/cpp/test_cache.cc
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2023 by XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/cache.h>
|
||||||
|
#include <xgboost/data.h> // DMatrix
|
||||||
|
|
||||||
|
#include <cstddef> // std::size_t
|
||||||
|
|
||||||
|
#include "helpers.h" // RandomDataGenerator
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace {
|
||||||
|
// Minimal cache payload used by the test: it only remembers the value it was built with.
struct CacheForTest {
  std::size_t i;

  explicit CacheForTest(std::size_t value) : i(value) {}
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
TEST(DMatrixCache, Basic) {
  std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 4;
  DMatrixCache<CacheForTest> cache(kCacheSize);

  // The DMatrix is created inside a lambda so that it is destroyed as soon as the
  // lambda returns. This leaves an expired entry behind, which lets us test how the
  // cache handles expired entries.
  auto add_cache = [&]() -> DMatrix* {
    auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
    cache.CacheItem(p_fmat, 3);
    return p_fmat.get();
  };
  auto m = add_cache();
  // The only entry has expired, it must be gone and no longer accessible.
  ASSERT_EQ(cache.Container().size(), 0);
  ASSERT_THROW(cache.Entry(m), dmlc::Error);

  // A live DMatrix can be cached and read back.
  auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
  auto item = cache.CacheItem(p_fmat, 1);
  ASSERT_EQ(cache.Entry(p_fmat.get())->i, 1);

  // Insert twice as many matrices as the cache can hold, all of them kept alive.
  std::vector<std::shared_ptr<DMatrix>> items;
  for (std::size_t i = 0; i < kCacheSize * 2; ++i) {
    items.emplace_back(RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix());
    auto const& newest = items.back();
    cache.CacheItem(newest, i);
    ASSERT_EQ(cache.Entry(newest.get())->i, i);
    ASSERT_LE(cache.Container().size(), kCacheSize);
    if (i <= kCacheSize) {
      continue;
    }
    // Old enough entries must have been evicted by now.
    auto k = i - kCacheSize - 1;
    ASSERT_THROW(cache.Entry(items[k].get()), dmlc::Error);
  }
}
|
||||||
|
} // namespace xgboost
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2017-2022 by XGBoost contributors
|
* Copyright 2017-2023 by XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
#include <xgboost/learner.h>
|
#include <xgboost/learner.h>
|
||||||
@ -10,6 +10,7 @@
|
|||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../../src/common/api_entry.h" // XGBAPIThreadLocalEntry
|
||||||
#include "../../src/common/io.h"
|
#include "../../src/common/io.h"
|
||||||
#include "../../src/common/linalg_op.h"
|
#include "../../src/common/linalg_op.h"
|
||||||
#include "../../src/common/random.h"
|
#include "../../src/common/random.h"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user