Move prediction cache to Learner. (#5220)

* Move prediction cache into Learner.

* Clean-ups

- Remove duplicated cache in Learner and GBM.
- Remove ad-hoc fix of invalid cache.
- Remove `PredictFromCache` in predictors.
- Remove prediction cache for linear altogether, as it's only moving the
  prediction into training process but doesn't provide any actual overall speed
  gain.
- The cache is now unique to Learner, which means the ownership is no longer
  shared by any other components.

* Changes

- Add version to prediction cache.
- Use weak ptr to check expired DMatrix.
- Pass shared pointer instead of raw pointer.
This commit is contained in:
Jiaming Yuan 2020-02-14 13:04:23 +08:00 committed by GitHub
parent 24ad9dec0b
commit c35cdecddd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 457 additions and 372 deletions

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright by Contributors * Copyright 2014-2020 by Contributors
* \file gbm.h * \file gbm.h
* \brief Interface of gradient booster, * \brief Interface of gradient booster,
* that learns through gradient statistics. * that learns through gradient statistics.
@ -18,6 +18,7 @@
#include <utility> #include <utility>
#include <string> #include <string>
#include <functional> #include <functional>
#include <unordered_map>
#include <memory> #include <memory>
namespace xgboost { namespace xgboost {
@ -28,6 +29,8 @@ class ObjFunction;
struct GenericParameter; struct GenericParameter;
struct LearnerModelParam; struct LearnerModelParam;
struct PredictionCacheEntry;
class PredictionContainer;
/*! /*!
* \brief interface of gradient boosting model. * \brief interface of gradient boosting model.
@ -38,7 +41,7 @@ class GradientBooster : public Model, public Configurable {
public: public:
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~GradientBooster() = default; ~GradientBooster() override = default;
/*! /*!
* \brief Set the configuration of gradient boosting. * \brief Set the configuration of gradient boosting.
* User must call configure once before InitModel and Training. * User must call configure once before InitModel and Training.
@ -71,19 +74,22 @@ class GradientBooster : public Model, public Configurable {
* \param obj The objective function, optional, can be nullptr when use customized version * \param obj The objective function, optional, can be nullptr when use customized version
* the booster may change content of gpair * the booster may change content of gpair
*/ */
virtual void DoBoost(DMatrix* p_fmat, virtual void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
HostDeviceVector<GradientPair>* in_gpair, PredictionCacheEntry *prediction) = 0;
ObjFunction* obj = nullptr) = 0;
/*! /*!
* \brief generate predictions for given feature matrix * \brief generate predictions for given feature matrix
* \param dmat feature matrix * \param dmat feature matrix
* \param out_preds output vector to hold the predictions * \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means * \param training Whether the prediction value is used for training. For dart booster
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear * drop out is performed during training.
* \param ntree_limit limit the number of trees used in prediction,
* when it equals 0, this means we do not limit
* number of trees, this parameter is only valid
* for gbtree, but not for gblinear
*/ */
virtual void PredictBatch(DMatrix* dmat, virtual void PredictBatch(DMatrix* dmat,
HostDeviceVector<bst_float>* out_preds, PredictionCacheEntry* out_preds,
bool training, bool training,
unsigned ntree_limit = 0) = 0; unsigned ntree_limit = 0) = 0;
/*! /*!
@ -158,8 +164,7 @@ class GradientBooster : public Model, public Configurable {
static GradientBooster* Create( static GradientBooster* Create(
const std::string& name, const std::string& name,
GenericParameter const* generic_param, GenericParameter const* generic_param,
LearnerModelParam const* learner_model_param, LearnerModelParam const* learner_model_param);
const std::vector<std::shared_ptr<DMatrix> >& cache_mats);
static void AssertGPUSupport() { static void AssertGPUSupport() {
#ifndef XGBOOST_USE_CUDA #ifndef XGBOOST_USE_CUDA
@ -174,8 +179,7 @@ class GradientBooster : public Model, public Configurable {
struct GradientBoosterReg struct GradientBoosterReg
: public dmlc::FunctionRegEntryBase< : public dmlc::FunctionRegEntryBase<
GradientBoosterReg, GradientBoosterReg,
std::function<GradientBooster* (const std::vector<std::shared_ptr<DMatrix> > &cached_mats, std::function<GradientBooster* (LearnerModelParam const* learner_model_param)> > {
LearnerModelParam const* learner_model_param)> > {
}; };
/*! /*!

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright by Contributors * Copyright 2017-2020 by Contributors
* \file predictor.h * \file predictor.h
* \brief Interface of predictor, * \brief Interface of predictor,
* performs predictions for a gradient booster. * performs predictions for a gradient booster.
@ -32,47 +32,83 @@ namespace xgboost {
* \brief Contains pointer to input matrix and associated cached predictions. * \brief Contains pointer to input matrix and associated cached predictions.
*/ */
struct PredictionCacheEntry { struct PredictionCacheEntry {
std::shared_ptr<DMatrix> data; // A storage for caching prediction values
HostDeviceVector<bst_float> predictions; HostDeviceVector<bst_float> predictions;
// The version of current cache, corresponding number of layers of trees
uint32_t version;
// A weak pointer for checking whether the DMatrix object has expired.
std::weak_ptr< DMatrix > ref;
PredictionCacheEntry() : version { 0 } {}
/* \brief Update the cache entry by number of versions.
*
* \param v Added versions.
*/
void Update(uint32_t v) {
version += v;
}
};
/* \brief A container for managed prediction caches.
*/
class PredictionContainer {
std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
void ClearExpiredEntries();
public:
PredictionContainer() = default;
/* \brief Add a new DMatrix to the cache, at the same time this function will clear out
* all expired caches by checking the `std::weak_ptr`. Caching an existing
* DMatrix won't renew it.
*
* Passing in a `shared_ptr` is critical here. First to create a `weak_ptr` inside the
* entry this shared pointer is necessary. More importantly, the life time of this
* cache is tied to the shared pointer.
*
* Another way to make a safe cache is create a proxy to this entry, with anther shared
* pointer defined inside, and pass this proxy around instead of the real entry. But
* seems to be too messy. In XGBoost, functions like `UpdateOneIter` will have
* (memory) safe access to the DMatrix as long as it's passed in as a `shared_ptr`.
*
* \param m shared pointer to the DMatrix that needs to be cached.
* \param device Which device should the cache be allocated on. Pass
* GenericParameter::kCpuId for CPU or positive integer for GPU id.
*
* \return the cache entry for passed in DMatrix, either an existing cache or newly
* created.
*/
PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device);
/* \brief Get a prediction cache entry. This entry must be already allocated by `Cache`
* method. Otherwise a dmlc::Error is thrown.
*
* \param m pointer to the DMatrix.
* \return The prediction cache for passed in DMatrix.
*/
PredictionCacheEntry& Entry(DMatrix* m);
/* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
*/
decltype(container_) const& Container();
}; };
/** /**
* \class Predictor * \class Predictor
* *
* \brief Performs prediction on individual training instances or batches of * \brief Performs prediction on individual training instances or batches of instances for
* instances for GBTree. The predictor also manages a prediction cache * GBTree. Prediction functions all take a GBTreeModel and a DMatrix as input and
* associated with input matrices. If possible, it will use previously * output a vector of predictions. The predictor does not modify any state of the
* calculated predictions instead of calculating new predictions. * model itself.
* Prediction functions all take a GBTreeModel and a DMatrix as input and
* output a vector of predictions. The predictor does not modify any state of
* the model itself.
*/ */
class Predictor { class Predictor {
protected: protected:
/* /*
* \brief Runtime parameters. * \brief Runtime parameters.
*/ */
GenericParameter const* generic_param_; GenericParameter const* generic_param_;
/**
* \brief Map of matrices and associated cached predictions to facilitate
* storing and looking up predictions.
*/
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache_;
std::unordered_map<DMatrix*, PredictionCacheEntry>::iterator FindCache(DMatrix const* dmat) {
auto cache_emtry = std::find_if(
cache_->begin(), cache_->end(),
[dmat](std::pair<DMatrix *, PredictionCacheEntry const &> const &kv) {
return kv.second.data.get() == dmat;
});
return cache_emtry;
}
public: public:
Predictor(GenericParameter const* generic_param, explicit Predictor(GenericParameter const* generic_param) :
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache) : generic_param_{generic_param} {}
generic_param_{generic_param}, cache_{cache} {}
virtual ~Predictor() = default; virtual ~Predictor() = default;
/** /**
@ -93,10 +129,9 @@ class Predictor {
* \param ntree_limit (Optional) The ntree limit. 0 means do not * \param ntree_limit (Optional) The ntree limit. 0 means do not
* limit trees. * limit trees.
*/ */
virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
virtual void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, int tree_begin, const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) = 0; uint32_t const ntree_limit = 0) = 0;
/** /**
* \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel * \fn virtual void Predictor::UpdatePredictionCache( const gbm::GBTreeModel
@ -116,7 +151,9 @@ class Predictor {
virtual void UpdatePredictionCache( virtual void UpdatePredictionCache(
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
std::vector<std::unique_ptr<TreeUpdater>>* updaters, std::vector<std::unique_ptr<TreeUpdater>>* updaters,
int num_new_trees) = 0; int num_new_trees,
DMatrix* m,
PredictionCacheEntry* predts) = 0;
/** /**
* \fn virtual void Predictor::PredictInstance( const SparsePage::Inst& * \fn virtual void Predictor::PredictInstance( const SparsePage::Inst&
@ -200,8 +237,7 @@ class Predictor {
* \param cache Pointer to prediction cache. * \param cache Pointer to prediction cache.
*/ */
static Predictor* Create( static Predictor* Create(
std::string const& name, GenericParameter const* generic_param, std::string const& name, GenericParameter const* generic_param);
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache);
}; };
/*! /*!
@ -209,9 +245,7 @@ class Predictor {
*/ */
struct PredictorReg struct PredictorReg
: public dmlc::FunctionRegEntryBase< : public dmlc::FunctionRegEntryBase<
PredictorReg, std::function<Predictor*( PredictorReg, std::function<Predictor*(GenericParameter const*)>> {};
GenericParameter const*,
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>>)>> {};
#define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \ #define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \
static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \ static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \

View File

@ -158,7 +158,7 @@ class RegTree : public Model {
} }
/*! \brief whether this node is deleted */ /*! \brief whether this node is deleted */
XGBOOST_DEVICE bool IsDeleted() const { XGBOOST_DEVICE bool IsDeleted() const {
return sindex_ == std::numeric_limits<unsigned>::max(); return sindex_ == std::numeric_limits<uint32_t>::max();
} }
/*! \brief whether current node is root */ /*! \brief whether current node is root */
XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; } XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }

View File

@ -15,6 +15,7 @@
#include "xgboost/gbm.h" #include "xgboost/gbm.h"
#include "xgboost/json.h" #include "xgboost/json.h"
#include "xgboost/predictor.h"
#include "xgboost/linear_updater.h" #include "xgboost/linear_updater.h"
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/learner.h" #include "xgboost/learner.h"
@ -50,21 +51,14 @@ struct GBLinearTrainParam : public XGBoostParameter<GBLinearTrainParam> {
*/ */
class GBLinear : public GradientBooster { class GBLinear : public GradientBooster {
public: public:
explicit GBLinear(const std::vector<std::shared_ptr<DMatrix> > &cache, explicit GBLinear(LearnerModelParam const* learner_model_param)
LearnerModelParam const* learner_model_param)
: learner_model_param_{learner_model_param}, : learner_model_param_{learner_model_param},
model_{learner_model_param_}, model_{learner_model_param_},
previous_model_{learner_model_param_}, previous_model_{learner_model_param_},
sum_instance_weight_(0), sum_instance_weight_(0),
sum_weight_complete_(false), sum_weight_complete_(false),
is_converged_(false) { is_converged_(false) {}
// Add matrices to the prediction cache
for (auto &d : cache) {
PredictionCacheEntry e;
e.data = d;
cache_[d.get()] = std::move(e);
}
}
void Configure(const Args& cfg) override { void Configure(const Args& cfg) override {
if (model_.weight.size() == 0) { if (model_.weight.size() == 0) {
model_.Configure(cfg); model_.Configure(cfg);
@ -118,7 +112,7 @@ class GBLinear : public GradientBooster {
void DoBoost(DMatrix *p_fmat, void DoBoost(DMatrix *p_fmat,
HostDeviceVector<GradientPair> *in_gpair, HostDeviceVector<GradientPair> *in_gpair,
ObjFunction* obj) override { PredictionCacheEntry* predt) override {
monitor_.Start("DoBoost"); monitor_.Start("DoBoost");
model_.LazyInitModel(); model_.LazyInitModel();
@ -127,28 +121,19 @@ class GBLinear : public GradientBooster {
if (!this->CheckConvergence()) { if (!this->CheckConvergence()) {
updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_); updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_);
} }
this->UpdatePredictionCache();
monitor_.Stop("DoBoost"); monitor_.Stop("DoBoost");
} }
void PredictBatch(DMatrix *p_fmat, void PredictBatch(DMatrix *p_fmat,
HostDeviceVector<bst_float> *out_preds, PredictionCacheEntry *predts,
bool training, bool training,
unsigned ntree_limit) override { unsigned ntree_limit) override {
monitor_.Start("PredictBatch"); monitor_.Start("PredictBatch");
auto* out_preds = &predts->predictions;
CHECK_EQ(ntree_limit, 0U) CHECK_EQ(ntree_limit, 0U)
<< "GBLinear::Predict ntrees is only valid for gbtree predictor"; << "GBLinear::Predict ntrees is only valid for gbtree predictor";
// Try to predict from cache
auto it = cache_.find(p_fmat);
if (it != cache_.end() && it->second.predictions.size() != 0) {
std::vector<bst_float> &y = it->second.predictions;
out_preds->Resize(y.size());
std::copy(y.begin(), y.end(), out_preds->HostVector().begin());
} else {
this->PredictBatchInternal(p_fmat, &out_preds->HostVector()); this->PredictBatchInternal(p_fmat, &out_preds->HostVector());
}
monitor_.Stop("PredictBatch"); monitor_.Stop("PredictBatch");
} }
// add base margin // add base margin
@ -258,7 +243,8 @@ class GBLinear : public GradientBooster {
const size_t ridx = batch.base_rowid + i; const size_t ridx = batch.base_rowid + i;
// loop over output groups // loop over output groups
for (int gid = 0; gid < ngroup; ++gid) { for (int gid = 0; gid < ngroup; ++gid) {
bst_float margin = (base_margin.size() != 0) ? bst_float margin =
(base_margin.size() != 0) ?
base_margin[ridx * ngroup + gid] : learner_model_param_->base_score; base_margin[ridx * ngroup + gid] : learner_model_param_->base_score;
this->Pred(batch[i], &preds[ridx * ngroup], gid, margin); this->Pred(batch[i], &preds[ridx * ngroup], gid, margin);
} }
@ -266,17 +252,6 @@ class GBLinear : public GradientBooster {
} }
monitor_.Stop("PredictBatchInternal"); monitor_.Stop("PredictBatchInternal");
} }
void UpdatePredictionCache() {
// update cache entry
for (auto &kv : cache_) {
PredictionCacheEntry &e = kv.second;
if (e.predictions.size() == 0) {
size_t n = model_.learner_model_param_->num_output_group * e.data->Info().num_row_;
e.predictions.resize(n);
}
this->PredictBatchInternal(e.data.get(), &e.predictions);
}
}
bool CheckConvergence() { bool CheckConvergence() {
if (param_.tolerance == 0.0f) return false; if (param_.tolerance == 0.0f) return false;
@ -327,22 +302,6 @@ class GBLinear : public GradientBooster {
bool sum_weight_complete_; bool sum_weight_complete_;
common::Monitor monitor_; common::Monitor monitor_;
bool is_converged_; bool is_converged_;
/**
* \struct PredictionCacheEntry
*
* \brief Contains pointer to input matrix and associated cached predictions.
*/
struct PredictionCacheEntry {
std::shared_ptr<DMatrix> data;
std::vector<bst_float> predictions;
};
/**
* \brief Map of matrices and associated cached predictions to facilitate
* storing and looking up predictions.
*/
std::unordered_map<DMatrix*, PredictionCacheEntry> cache_;
}; };
// register the objective functions // register the objective functions
@ -350,9 +309,8 @@ DMLC_REGISTER_PARAMETER(GBLinearTrainParam);
XGBOOST_REGISTER_GBM(GBLinear, "gblinear") XGBOOST_REGISTER_GBM(GBLinear, "gblinear")
.describe("Linear booster, implement generalized linear model.") .describe("Linear booster, implement generalized linear model.")
.set_body([](const std::vector<std::shared_ptr<DMatrix> > &cache, .set_body([](LearnerModelParam const* booster_config) {
LearnerModelParam const* booster_config) { return new GBLinear(booster_config);
return new GBLinear(cache, booster_config);
}); });
} // namespace gbm } // namespace gbm
} // namespace xgboost } // namespace xgboost

View File

@ -55,8 +55,9 @@ class GBLinearModel : public Model {
std::vector<bst_float> weight; std::vector<bst_float> weight;
// initialize the model parameter // initialize the model parameter
inline void LazyInitModel() { inline void LazyInitModel() {
if (!weight.empty()) if (!weight.empty()) {
return; return;
}
// bias is the last weight // bias is the last weight
weight.resize((learner_model_param_->num_feature + 1) * weight.resize((learner_model_param_->num_feature + 1) *
learner_model_param_->num_output_group); learner_model_param_->num_output_group);

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2015 by Contributors * Copyright 2015-2020 by Contributors
* \file gbm.cc * \file gbm.cc
* \brief Registry of gradient boosters. * \brief Registry of gradient boosters.
*/ */
@ -20,13 +20,12 @@ namespace xgboost {
GradientBooster* GradientBooster::Create( GradientBooster* GradientBooster::Create(
const std::string& name, const std::string& name,
GenericParameter const* generic_param, GenericParameter const* generic_param,
LearnerModelParam const* learner_model_param, LearnerModelParam const* learner_model_param) {
const std::vector<std::shared_ptr<DMatrix> >& cache_mats) {
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name); auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
if (e == nullptr) { if (e == nullptr) {
LOG(FATAL) << "Unknown gbm type " << name; LOG(FATAL) << "Unknown gbm type " << name;
} }
auto p_bst = (e->body)(cache_mats, learner_model_param); auto p_bst = (e->body)(learner_model_param);
p_bst->generic_param_ = generic_param; p_bst->generic_param_ = generic_param;
return p_bst; return p_bst;
} }

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014-2019 by Contributors * Copyright 2014-2020 by Contributors
* \file gbtree.cc * \file gbtree.cc
* \brief gradient boosted tree implementation. * \brief gradient boosted tree implementation.
* \author Tianqi Chen * \author Tianqi Chen
@ -14,6 +14,7 @@
#include <limits> #include <limits>
#include <algorithm> #include <algorithm>
#include "xgboost/data.h"
#include "xgboost/gbm.h" #include "xgboost/gbm.h"
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/json.h" #include "xgboost/json.h"
@ -47,14 +48,14 @@ void GBTree::Configure(const Args& cfg) {
// configure predictors // configure predictors
if (!cpu_predictor_) { if (!cpu_predictor_) {
cpu_predictor_ = std::unique_ptr<Predictor>( cpu_predictor_ = std::unique_ptr<Predictor>(
Predictor::Create("cpu_predictor", this->generic_param_, cache_)); Predictor::Create("cpu_predictor", this->generic_param_));
} }
cpu_predictor_->Configure(cfg); cpu_predictor_->Configure(cfg);
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
auto n_gpus = common::AllVisibleGPUs(); auto n_gpus = common::AllVisibleGPUs();
if (!gpu_predictor_ && n_gpus != 0) { if (!gpu_predictor_ && n_gpus != 0) {
gpu_predictor_ = std::unique_ptr<Predictor>( gpu_predictor_ = std::unique_ptr<Predictor>(
Predictor::Create("gpu_predictor", this->generic_param_, cache_)); Predictor::Create("gpu_predictor", this->generic_param_));
} }
if (n_gpus != 0) { if (n_gpus != 0) {
gpu_predictor_->Configure(cfg); gpu_predictor_->Configure(cfg);
@ -183,7 +184,7 @@ void GBTree::ConfigureUpdaters() {
void GBTree::DoBoost(DMatrix* p_fmat, void GBTree::DoBoost(DMatrix* p_fmat,
HostDeviceVector<GradientPair>* in_gpair, HostDeviceVector<GradientPair>* in_gpair,
ObjFunction* obj) { PredictionCacheEntry* predt) {
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees; std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
const int ngroup = model_.learner_model_param_->num_output_group; const int ngroup = model_.learner_model_param_->num_output_group;
ConfigureWithKnownData(this->cfg_, p_fmat); ConfigureWithKnownData(this->cfg_, p_fmat);
@ -214,7 +215,7 @@ void GBTree::DoBoost(DMatrix* p_fmat,
} }
} }
monitor_.Stop("BoostNewTrees"); monitor_.Stop("BoostNewTrees");
this->CommitModel(std::move(new_trees)); this->CommitModel(std::move(new_trees), p_fmat, predt);
} }
void GBTree::InitUpdater(Args const& cfg) { void GBTree::InitUpdater(Args const& cfg) {
@ -286,7 +287,9 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
} }
} }
void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) { void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
DMatrix* m,
PredictionCacheEntry* predts) {
monitor_.Start("CommitModel"); monitor_.Start("CommitModel");
int num_new_trees = 0; int num_new_trees = 0;
for (uint32_t gid = 0; gid < model_.learner_model_param_->num_output_group; ++gid) { for (uint32_t gid = 0; gid < model_.learner_model_param_->num_output_group; ++gid) {
@ -294,7 +297,7 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
model_.CommitModel(std::move(new_trees[gid]), gid); model_.CommitModel(std::move(new_trees[gid]), gid);
} }
CHECK(configured_); CHECK(configured_);
GetPredictor()->UpdatePredictionCache(model_, &updaters_, num_new_trees); GetPredictor()->UpdatePredictionCache(model_, &updaters_, num_new_trees, m, predts);
monitor_.Stop("CommitModel"); monitor_.Stop("CommitModel");
} }
@ -303,13 +306,16 @@ void GBTree::LoadConfig(Json const& in) {
fromJson(in["gbtree_train_param"], &tparam_); fromJson(in["gbtree_train_param"], &tparam_);
int32_t const n_gpus = xgboost::common::AllVisibleGPUs(); int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) { if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) {
LOG(WARNING)
<< "Loading from a raw memory buffer on CPU only machine. "
"Changing predictor to auto.";
tparam_.UpdateAllowUnknown(Args{{"predictor", "auto"}}); tparam_.UpdateAllowUnknown(Args{{"predictor", "auto"}});
} }
if (n_gpus == 0 && tparam_.tree_method == TreeMethod::kGPUHist) { if (n_gpus == 0 && tparam_.tree_method == TreeMethod::kGPUHist) {
tparam_.UpdateAllowUnknown(Args{{"tree_method", "hist"}}); tparam_.UpdateAllowUnknown(Args{{"tree_method", "hist"}});
LOG(WARNING) LOG(WARNING)
<< "Loading from a raw memory buffer on CPU only machine. " << "Loading from a raw memory buffer on CPU only machine. "
"Change tree_method to hist."; "Changing tree_method to hist.";
} }
auto const& j_updaters = get<Object const>(in["updater"]); auto const& j_updaters = get<Object const>(in["updater"]);
@ -415,7 +421,7 @@ class Dart : public GBTree {
} }
void PredictBatch(DMatrix* p_fmat, void PredictBatch(DMatrix* p_fmat,
HostDeviceVector<bst_float>* p_out_preds, PredictionCacheEntry* p_out_preds,
bool training, bool training,
unsigned ntree_limit) override { unsigned ntree_limit) override {
DropTrees(training); DropTrees(training);
@ -426,7 +432,7 @@ class Dart : public GBTree {
} }
size_t n = num_group * p_fmat->Info().num_row_; size_t n = num_group * p_fmat->Info().num_row_;
const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector(); const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector();
auto& out_preds = p_out_preds->HostVector(); auto& out_preds = p_out_preds->predictions.HostVector();
out_preds.resize(n); out_preds.resize(n);
if (base_margin.size() != 0) { if (base_margin.size() != 0) {
CHECK_EQ(out_preds.size(), n); CHECK_EQ(out_preds.size(), n);
@ -539,7 +545,9 @@ class Dart : public GBTree {
// commit new trees all at once // commit new trees all at once
void void
CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) override { CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
DMatrix* m,
PredictionCacheEntry* predts) override {
int num_new_trees = 0; int num_new_trees = 0;
for (uint32_t gid = 0; gid < model_.learner_model_param_->num_output_group; ++gid) { for (uint32_t gid = 0; gid < model_.learner_model_param_->num_output_group; ++gid) {
num_new_trees += new_trees[gid].size(); num_new_trees += new_trees[gid].size();
@ -681,16 +689,13 @@ DMLC_REGISTER_PARAMETER(DartTrainParam);
XGBOOST_REGISTER_GBM(GBTree, "gbtree") XGBOOST_REGISTER_GBM(GBTree, "gbtree")
.describe("Tree booster, gradient boosted trees.") .describe("Tree booster, gradient boosted trees.")
.set_body([](const std::vector<std::shared_ptr<DMatrix> >& cached_mats, .set_body([](LearnerModelParam const* booster_config) {
LearnerModelParam const* booster_config) {
auto* p = new GBTree(booster_config); auto* p = new GBTree(booster_config);
p->InitCache(cached_mats);
return p; return p;
}); });
XGBOOST_REGISTER_GBM(Dart, "dart") XGBOOST_REGISTER_GBM(Dart, "dart")
.describe("Tree booster, dart.") .describe("Tree booster, dart.")
.set_body([](const std::vector<std::shared_ptr<DMatrix> >& cached_mats, .set_body([](LearnerModelParam const* booster_config) {
LearnerModelParam const* booster_config) {
GBTree* p = new Dart(booster_config); GBTree* p = new Dart(booster_config);
return p; return p;
}); });

View File

@ -16,6 +16,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "xgboost/data.h"
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/gbm.h" #include "xgboost/gbm.h"
#include "xgboost/predictor.h" #include "xgboost/predictor.h"
@ -151,14 +152,8 @@ struct DartTrainParam : public XGBoostParameter<DartTrainParam> {
// gradient boosted trees // gradient boosted trees
class GBTree : public GradientBooster { class GBTree : public GradientBooster {
public: public:
explicit GBTree(LearnerModelParam const* booster_config) : model_(booster_config) {} explicit GBTree(LearnerModelParam const* booster_config) :
model_(booster_config) {}
void InitCache(const std::vector<std::shared_ptr<DMatrix> > &cache) {
cache_ = std::make_shared<std::unordered_map<DMatrix*, PredictionCacheEntry>>();
for (std::shared_ptr<DMatrix> const& d : cache) {
(*cache_)[d.get()].data = d;
}
}
void Configure(const Args& cfg) override; void Configure(const Args& cfg) override;
// Revise `tree_method` and `updater` parameters after seeing the training // Revise `tree_method` and `updater` parameters after seeing the training
@ -171,7 +166,7 @@ class GBTree : public GradientBooster {
/*! \brief Carry out one iteration of boosting */ /*! \brief Carry out one iteration of boosting */
void DoBoost(DMatrix* p_fmat, void DoBoost(DMatrix* p_fmat,
HostDeviceVector<GradientPair>* in_gpair, HostDeviceVector<GradientPair>* in_gpair,
ObjFunction* obj) override; PredictionCacheEntry* predt) override;
bool UseGPU() const override { bool UseGPU() const override {
return return
@ -204,11 +199,12 @@ class GBTree : public GradientBooster {
} }
void PredictBatch(DMatrix* p_fmat, void PredictBatch(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_preds, PredictionCacheEntry* out_preds,
bool training, bool training,
unsigned ntree_limit) override { unsigned ntree_limit) override {
CHECK(configured_); CHECK(configured_);
GetPredictor(out_preds, p_fmat)->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit); GetPredictor(&out_preds->predictions, p_fmat)->PredictBatch(
p_fmat, out_preds, model_, 0, ntree_limit);
} }
void PredictInstance(const SparsePage::Inst& inst, void PredictInstance(const SparsePage::Inst& inst,
@ -318,7 +314,9 @@ class GBTree : public GradientBooster {
} }
// commit new trees all at once // commit new trees all at once
virtual void CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees); virtual void CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
DMatrix* m,
PredictionCacheEntry* predts);
// --- data structure --- // --- data structure ---
GBTreeModel model_; GBTreeModel model_;
@ -332,11 +330,6 @@ class GBTree : public GradientBooster {
Args cfg_; Args cfg_;
// the updaters that can be applied to each of tree // the updaters that can be applied to each of tree
std::vector<std::unique_ptr<TreeUpdater>> updaters_; std::vector<std::unique_ptr<TreeUpdater>> updaters_;
/**
* \brief Map of matrices and associated cached predictions to facilitate
* storing and looking up predictions.
*/
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache_;
// Predictors // Predictors
std::unique_ptr<Predictor> cpu_predictor_; std::unique_ptr<Predictor> cpu_predictor_;
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)

View File

@ -10,6 +10,7 @@
#include <algorithm> #include <algorithm>
#include <iomanip> #include <iomanip>
#include <limits> #include <limits>
#include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <stack> #include <stack>
@ -17,6 +18,8 @@
#include <vector> #include <vector>
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/predictor.h"
#include "xgboost/feature_map.h" #include "xgboost/feature_map.h"
#include "xgboost/gbm.h" #include "xgboost/gbm.h"
#include "xgboost/generic_parameters.h" #include "xgboost/generic_parameters.h"
@ -196,8 +199,11 @@ void GenericParameter::ConfigureGpuId(bool require_gpu) {
class LearnerImpl : public Learner { class LearnerImpl : public Learner {
public: public:
explicit LearnerImpl(std::vector<std::shared_ptr<DMatrix> > cache) explicit LearnerImpl(std::vector<std::shared_ptr<DMatrix> > cache)
: need_configuration_{true}, cache_(std::move(cache)) { : need_configuration_{true} {
monitor_.Init("Learner"); monitor_.Init("Learner");
for (std::shared_ptr<DMatrix> const& d : cache) {
cache_.Cache(d, GenericParameter::kCpuId);
}
} }
// Configuration before data is known. // Configuration before data is known.
void Configure() override { void Configure() override {
@ -358,8 +364,7 @@ class LearnerImpl : public Learner {
name = get<String>(gradient_booster["name"]); name = get<String>(gradient_booster["name"]);
tparam_.UpdateAllowUnknown(Args{{"booster", name}}); tparam_.UpdateAllowUnknown(Args{{"booster", name}});
gbm_.reset(GradientBooster::Create(tparam_.booster, gbm_.reset(GradientBooster::Create(tparam_.booster,
&generic_parameters_, &learner_model_param_, &generic_parameters_, &learner_model_param_));
cache_));
gbm_->LoadModel(gradient_booster); gbm_->LoadModel(gradient_booster);
auto const& j_attributes = get<Object const>(learner.at("attributes")); auto const& j_attributes = get<Object const>(learner.at("attributes"));
@ -413,8 +418,7 @@ class LearnerImpl : public Learner {
tparam_.booster = get<String>(gradient_booster["name"]); tparam_.booster = get<String>(gradient_booster["name"]);
if (!gbm_) { if (!gbm_) {
gbm_.reset(GradientBooster::Create(tparam_.booster, gbm_.reset(GradientBooster::Create(tparam_.booster,
&generic_parameters_, &learner_model_param_, &generic_parameters_, &learner_model_param_));
cache_));
} }
gbm_->LoadConfig(gradient_booster); gbm_->LoadConfig(gradient_booster);
@ -500,7 +504,7 @@ class LearnerImpl : public Learner {
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_)); obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_));
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_, gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_,
&learner_model_param_, cache_)); &learner_model_param_));
gbm_->Load(fi); gbm_->Load(fi);
if (mparam_.contain_extra_attrs != 0) { if (mparam_.contain_extra_attrs != 0) {
std::vector<std::pair<std::string, std::string> > attr; std::vector<std::pair<std::string, std::string> > attr;
@ -726,17 +730,18 @@ class LearnerImpl : public Learner {
this->CheckDataSplitMode(); this->CheckDataSplitMode();
this->ValidateDMatrix(train.get()); this->ValidateDMatrix(train.get());
auto& predt = this->cache_.Cache(train, generic_parameters_.gpu_id);
monitor_.Start("PredictRaw"); monitor_.Start("PredictRaw");
this->PredictRaw(train.get(), &preds_[train.get()], true); this->PredictRaw(train.get(), &predt, true);
monitor_.Stop("PredictRaw"); monitor_.Stop("PredictRaw");
TrainingObserver::Instance().Observe(preds_[train.get()], "Predictions");
monitor_.Start("GetGradient"); monitor_.Start("GetGradient");
obj_->GetGradient(preds_[train.get()], train->Info(), iter, &gpair_); obj_->GetGradient(predt.predictions, train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient"); monitor_.Stop("GetGradient");
TrainingObserver::Instance().Observe(gpair_, "Gradients"); TrainingObserver::Instance().Observe(gpair_, "Gradients");
gbm_->DoBoost(train.get(), &gpair_, obj_.get()); gbm_->DoBoost(train.get(), &gpair_, &predt);
monitor_.Stop("UpdateOneIter"); monitor_.Stop("UpdateOneIter");
} }
@ -749,12 +754,14 @@ class LearnerImpl : public Learner {
} }
this->CheckDataSplitMode(); this->CheckDataSplitMode();
this->ValidateDMatrix(train.get()); this->ValidateDMatrix(train.get());
this->cache_.Cache(train, generic_parameters_.gpu_id);
gbm_->DoBoost(train.get(), in_gpair); gbm_->DoBoost(train.get(), in_gpair, &cache_.Entry(train.get()));
monitor_.Stop("BoostOneIter"); monitor_.Stop("BoostOneIter");
} }
std::string EvalOneIter(int iter, const std::vector<std::shared_ptr<DMatrix>>& data_sets, std::string EvalOneIter(int iter,
const std::vector<std::shared_ptr<DMatrix>>& data_sets,
const std::vector<std::string>& data_names) override { const std::vector<std::string>& data_names) override {
monitor_.Start("EvalOneIter"); monitor_.Start("EvalOneIter");
this->Configure(); this->Configure();
@ -766,14 +773,19 @@ class LearnerImpl : public Learner {
metrics_.back()->Configure({cfg_.begin(), cfg_.end()}); metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
} }
for (size_t i = 0; i < data_sets.size(); ++i) { for (size_t i = 0; i < data_sets.size(); ++i) {
DMatrix * dmat = data_sets[i].get(); std::shared_ptr<DMatrix> m = data_sets[i];
this->ValidateDMatrix(dmat); auto &predt = this->cache_.Cache(m, generic_parameters_.gpu_id);
this->PredictRaw(dmat, &preds_[dmat], false); this->ValidateDMatrix(m.get());
obj_->EvalTransform(&preds_[dmat]); this->PredictRaw(m.get(), &predt, false);
auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions;
out.Resize(predt.predictions.Size());
out.Copy(predt.predictions);
obj_->EvalTransform(&out);
for (auto& ev : metrics_) { for (auto& ev : metrics_) {
os << '\t' << data_names[i] << '-' << ev->Name() << ':' os << '\t' << data_names[i] << '-' << ev->Name() << ':'
<< ev->Eval(preds_[dmat], data_sets[i]->Info(), << ev->Eval(out, m->Info(), tparam_.dsplit == DataSplitMode::kRow);
tparam_.dsplit == DataSplitMode::kRow);
} }
} }
@ -848,7 +860,12 @@ class LearnerImpl : public Learner {
} else if (pred_leaf) { } else if (pred_leaf) {
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit); gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
} else { } else {
this->PredictRaw(data.get(), out_preds, training, ntree_limit); auto& prediction = cache_.Cache(data, generic_parameters_.gpu_id);
this->PredictRaw(data.get(), &prediction, training, ntree_limit);
// Copy the prediction cache to output prediction. out_preds comes from C API
out_preds->SetDevice(generic_parameters_.gpu_id);
out_preds->Resize(prediction.predictions.Size());
out_preds->Copy(prediction.predictions);
if (!output_margin) { if (!output_margin) {
obj_->PredTransform(out_preds); obj_->PredTransform(out_preds);
} }
@ -868,11 +885,10 @@ class LearnerImpl : public Learner {
* predictor, when it equals 0, this means we are using all the trees * predictor, when it equals 0, this means we are using all the trees
* \param training allow dropout when the DART booster is being used * \param training allow dropout when the DART booster is being used
*/ */
void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds, void PredictRaw(DMatrix* data, PredictionCacheEntry* out_preds,
bool training, bool training,
unsigned ntree_limit = 0) const { unsigned ntree_limit = 0) const {
CHECK(gbm_ != nullptr) CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
<< "Predict must happen after Load or configuration";
this->ValidateDMatrix(data); this->ValidateDMatrix(data);
gbm_->PredictBatch(data, out_preds, training, ntree_limit); gbm_->PredictBatch(data, out_preds, training, ntree_limit);
} }
@ -920,7 +936,7 @@ class LearnerImpl : public Learner {
void ConfigureGBM(LearnerTrainParam const& old, Args const& args) { void ConfigureGBM(LearnerTrainParam const& old, Args const& args) {
if (gbm_ == nullptr || old.booster != tparam_.booster) { if (gbm_ == nullptr || old.booster != tparam_.booster) {
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_, gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_,
&learner_model_param_, cache_)); &learner_model_param_));
} }
gbm_->Configure(args); gbm_->Configure(args);
} }
@ -930,9 +946,10 @@ class LearnerImpl : public Learner {
// estimate feature bound // estimate feature bound
// TODO(hcho3): Change num_feature to 64-bit integer // TODO(hcho3): Change num_feature to 64-bit integer
unsigned num_feature = 0; unsigned num_feature = 0;
for (auto & matrix : cache_) { for (auto & matrix : cache_.Container()) {
CHECK(matrix != nullptr); CHECK(matrix.first);
const uint64_t num_col = matrix->Info().num_col_; CHECK(!matrix.second.ref.expired());
const uint64_t num_col = matrix.first->Info().num_col_;
CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max())) CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
<< "Unfortunately, XGBoost does not support data matrices with " << "Unfortunately, XGBoost does not support data matrices with "
<< std::numeric_limits<unsigned>::max() << " features or greater"; << std::numeric_limits<unsigned>::max() << " features or greater";
@ -990,13 +1007,12 @@ class LearnerImpl : public Learner {
// `enable_experimental_json_serialization' is set to false. Will be removed once JSON // `enable_experimental_json_serialization' is set to false. Will be removed once JSON
// takes over. // takes over.
std::string const serialisation_header_ { u8"CONFIG-offset:" }; std::string const serialisation_header_ { u8"CONFIG-offset:" };
// configurations // User provided configurations
std::map<std::string, std::string> cfg_; std::map<std::string, std::string> cfg_;
// Stores information like best-iteration for early stopping.
std::map<std::string, std::string> attributes_; std::map<std::string, std::string> attributes_;
std::vector<std::string> metric_names_; std::vector<std::string> metric_names_;
static std::string const kEvalMetric; // NOLINT static std::string const kEvalMetric; // NOLINT
// temporal storages for prediction
std::map<DMatrix*, HostDeviceVector<bst_float>> preds_;
// gradient pairs // gradient pairs
HostDeviceVector<GradientPair> gpair_; HostDeviceVector<GradientPair> gpair_;
bool need_configuration_; bool need_configuration_;
@ -1004,8 +1020,11 @@ class LearnerImpl : public Learner {
private: private:
/*! \brief random number transformation seed. */ /*! \brief random number transformation seed. */
static int32_t constexpr kRandSeedMagic = 127; static int32_t constexpr kRandSeedMagic = 127;
// internal cached dmatrix // internal cached dmatrix for prediction.
std::vector<std::shared_ptr<DMatrix> > cache_; PredictionContainer cache_;
/*! \brief Temporary storage to prediction. Useful for storing data transformed by
* objective function */
PredictionContainer output_predictions_;
common::Monitor monitor_; common::Monitor monitor_;

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright by Contributors 2017-2019 * Copyright by Contributors 2017-2020
*/ */
#include <dmlc/omp.h> #include <dmlc/omp.h>
@ -46,7 +46,7 @@ class CPUPredictor : public Predictor {
} }
} }
void PredLoopInternal(DMatrix* p_fmat, std::vector<bst_float>* out_preds, void PredInternal(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, gbm::GBTreeModel const &model, int32_t tree_begin,
int32_t tree_end) { int32_t tree_end) {
int32_t const num_group = model.learner_model_param_->num_output_group; int32_t const num_group = model.learner_model_param_->num_output_group;
@ -102,27 +102,6 @@ class CPUPredictor : public Predictor {
} }
} }
bool PredictFromCache(DMatrix* dmat,
HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model,
unsigned ntree_limit) const {
CHECK(cache_);
if (ntree_limit == 0 ||
ntree_limit * model.learner_model_param_->num_output_group >= model.trees.size()) {
auto it = cache_->find(dmat);
if (it != cache_->end()) {
const HostDeviceVector<bst_float>& y = it->second.predictions;
if (y.Size() != 0) {
out_preds->Resize(y.Size());
std::copy(y.HostVector().begin(), y.HostVector().end(),
out_preds->HostVector().begin());
return true;
}
}
}
return false;
}
void InitOutPredictions(const MetaInfo& info, void InitOutPredictions(const MetaInfo& info,
HostDeviceVector<bst_float>* out_preds, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model) const { const gbm::GBTreeModel& model) const {
@ -156,60 +135,78 @@ class CPUPredictor : public Predictor {
} }
public: public:
CPUPredictor(GenericParameter const* generic_param, explicit CPUPredictor(GenericParameter const* generic_param) :
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache) : Predictor::Predictor{generic_param} {}
Predictor::Predictor{generic_param, cache} {} // ntree_limit is a very problematic parameter, as it's ambiguous in the context of
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds, // multi-output and forest. Same problem exists for tree_begin
void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
const gbm::GBTreeModel& model, int tree_begin, const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) override { uint32_t const ntree_limit = 0) override {
if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) { // tree_begin is not used, right now we just enforce it to be 0.
return; CHECK_EQ(tree_begin, 0);
} auto* out_preds = &predts->predictions;
CHECK_GE(predts->version, tree_begin);
if (predts->version == 0) {
CHECK_EQ(out_preds->Size(), 0);
this->InitOutPredictions(dmat->Info(), out_preds, model); this->InitOutPredictions(dmat->Info(), out_preds, model);
ntree_limit *= model.learner_model_param_->num_output_group;
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
ntree_limit = static_cast<unsigned>(model.trees.size());
} }
this->PredLoopInternal(dmat, &out_preds->HostVector(), model, uint32_t const output_groups = model.learner_model_param_->num_output_group;
tree_begin, ntree_limit); CHECK_NE(output_groups, 0);
// Right now we just assume ntree_limit provided by users means number of tree layers
// in the context of multi-output model
uint32_t real_ntree_limit = ntree_limit * output_groups;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
}
auto cache_entry = this->FindCache(dmat); uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;
if (cache_entry == cache_->cend()) { // When users have provided ntree_limit, end_version can be lesser, cache is violated
return; if (predts->version > end_version) {
CHECK_NE(ntree_limit, 0);
this->InitOutPredictions(dmat->Info(), out_preds, model);
predts->version = 0;
} }
if (cache_entry->second.predictions.Size() == 0) { uint32_t const beg_version = predts->version;
// See comment in GPUPredictor::PredictBatch. CHECK_LE(beg_version, end_version);
InitOutPredictions(cache_entry->second.data->Info(),
&(cache_entry->second.predictions), model); if (beg_version < end_version) {
cache_entry->second.predictions.Copy(*out_preds); this->PredInternal(dmat, &out_preds->HostVector(), model,
beg_version * output_groups,
end_version * output_groups);
} }
// delta means {size of forest} * {number of newly accumulated layers}
uint32_t delta = end_version - beg_version;
CHECK_LE(delta, model.trees.size());
predts->Update(delta);
CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
out_preds->Size() == dmat->Info().num_row_);
} }
void UpdatePredictionCache( void UpdatePredictionCache(
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
std::vector<std::unique_ptr<TreeUpdater>>* updaters, std::vector<std::unique_ptr<TreeUpdater>>* updaters,
int num_new_trees) override { int num_new_trees,
DMatrix* m,
PredictionCacheEntry* predts) override {
int old_ntree = model.trees.size() - num_new_trees; int old_ntree = model.trees.size() - num_new_trees;
// update cache entry // update cache entry
for (auto& kv : (*cache_)) { auto* out = &predts->predictions;
PredictionCacheEntry& e = kv.second; if (predts->predictions.Size() == 0) {
this->InitOutPredictions(m->Info(), out, model);
if (e.predictions.Size() == 0) { this->PredInternal(m, &out->HostVector(), model, 0, model.trees.size());
InitOutPredictions(e.data->Info(), &(e.predictions), model); } else if (model.learner_model_param_->num_output_group == 1 &&
PredLoopInternal(e.data.get(), &(e.predictions.HostVector()), model, 0, updaters->size() > 0 &&
model.trees.size());
} else if (model.learner_model_param_->num_output_group == 1 && updaters->size() > 0 &&
num_new_trees == 1 && num_new_trees == 1 &&
updaters->back()->UpdatePredictionCache(e.data.get(), updaters->back()->UpdatePredictionCache(m, out)) {
&(e.predictions))) { {}
{} // do nothing
} else { } else {
PredLoopInternal(e.data.get(), &(e.predictions.HostVector()), model, old_ntree, PredInternal(m, &out->HostVector(), model, old_ntree, model.trees.size());
model.trees.size());
}
} }
auto delta = num_new_trees / model.learner_model_param_->num_output_group;
predts->Update(delta);
} }
void PredictInstance(const SparsePage::Inst& inst, void PredictInstance(const SparsePage::Inst& inst,
@ -387,9 +384,8 @@ class CPUPredictor : public Predictor {
XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor") XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor")
.describe("Make predictions using CPU.") .describe("Make predictions using CPU.")
.set_body([](GenericParameter const* generic_param, .set_body([](GenericParameter const* generic_param) {
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache) { return new CPUPredictor(generic_param);
return new CPUPredictor(generic_param, cache);
}); });
} // namespace predictor } // namespace predictor
} // namespace xgboost } // namespace xgboost

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2018 by Contributors * Copyright 2017-2020 by Contributors
*/ */
#include <thrust/copy.h> #include <thrust/copy.h>
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
@ -295,9 +295,8 @@ class GPUPredictor : public xgboost::Predictor {
} }
public: public:
GPUPredictor(GenericParameter const* generic_param, explicit GPUPredictor(GenericParameter const* generic_param) :
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache) : Predictor::Predictor{generic_param} {}
Predictor::Predictor{generic_param, cache} {}
~GPUPredictor() override { ~GPUPredictor() override {
if (generic_param_->gpu_id >= 0) { if (generic_param_->gpu_id >= 0) {
@ -305,43 +304,53 @@ class GPUPredictor : public xgboost::Predictor {
} }
} }
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds, void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
const gbm::GBTreeModel& model, int tree_begin, const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) override { unsigned ntree_limit = 0) override {
// This function is duplicated with CPU predictor PredictBatch, see comments in there.
// FIXME(trivialfis): Remove the duplication.
int device = generic_param_->gpu_id; int device = generic_param_->gpu_id;
CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data."; CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
ConfigureDevice(device); ConfigureDevice(device);
if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) { CHECK_EQ(tree_begin, 0);
return; auto* out_preds = &predts->predictions;
} CHECK_GE(predts->version, tree_begin);
if (predts->version == 0) {
CHECK_EQ(out_preds->Size(), 0);
this->InitOutPredictions(dmat->Info(), out_preds, model); this->InitOutPredictions(dmat->Info(), out_preds, model);
int32_t tree_end = ntree_limit * model.learner_model_param_->num_output_group;
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
tree_end = static_cast<unsigned>(model.trees.size());
} }
DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end); uint32_t const output_groups = model.learner_model_param_->num_output_group;
CHECK_NE(output_groups, 0);
auto cache_emtry = this->FindCache(dmat); uint32_t real_ntree_limit = ntree_limit * output_groups;
if (cache_emtry == cache_->cend()) { return; } if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
if (cache_emtry->second.predictions.Size() == 0) { real_ntree_limit = static_cast<uint32_t>(model.trees.size());
// Initialise the cache on first iteration, this comes useful
// when performing training continuation:
//
// 1. PredictBatch
// 2. CommitModel
// - updater->UpdatePredictionCache
//
// If we don't initialise this cache, the 2 step will recieve an invalid cache as
// the first step only modifies prediction store in learner without following code.
InitOutPredictions(cache_emtry->second.data->Info(),
&(cache_emtry->second.predictions), model);
CHECK_EQ(cache_emtry->second.predictions.Size(), out_preds->Size());
cache_emtry->second.predictions.Copy(*out_preds);
} }
uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;
if (predts->version > end_version) {
CHECK_NE(ntree_limit, 0);
this->InitOutPredictions(dmat->Info(), out_preds, model);
predts->version = 0;
}
uint32_t const beg_version = predts->version;
CHECK_LE(beg_version, end_version);
if (beg_version < end_version) {
this->DevicePredictInternal(dmat, out_preds, model,
beg_version * output_groups,
end_version * output_groups);
}
uint32_t delta = end_version - beg_version;
CHECK_LE(delta, model.trees.size());
predts->Update(delta);
CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
out_preds->Size() == dmat->Info().num_row_);
} }
protected: protected:
@ -361,49 +370,30 @@ class GPUPredictor : public xgboost::Predictor {
} }
} }
bool PredictFromCache(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, unsigned ntree_limit) {
if (ntree_limit == 0 ||
ntree_limit * model.learner_model_param_->num_output_group >= model.trees.size()) {
auto it = (*cache_).find(dmat);
if (it != cache_->cend()) {
const HostDeviceVector<bst_float>& y = it->second.predictions;
if (y.Size() != 0) {
monitor_.StartCuda("PredictFromCache");
out_preds->SetDevice(y.DeviceIdx());
out_preds->Resize(y.Size());
out_preds->Copy(y);
monitor_.StopCuda("PredictFromCache");
return true;
}
}
}
return false;
}
void UpdatePredictionCache( void UpdatePredictionCache(
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
std::vector<std::unique_ptr<TreeUpdater>>* updaters, std::vector<std::unique_ptr<TreeUpdater>>* updaters,
int num_new_trees) override { int num_new_trees,
DMatrix* m,
PredictionCacheEntry* predts) override {
int device = generic_param_->gpu_id;
ConfigureDevice(device);
auto old_ntree = model.trees.size() - num_new_trees; auto old_ntree = model.trees.size() - num_new_trees;
// update cache entry // update cache entry
for (auto& kv : (*cache_)) { auto* out = &predts->predictions;
PredictionCacheEntry& e = kv.second; if (predts->predictions.Size() == 0) {
DMatrix* dmat = kv.first; InitOutPredictions(m->Info(), out, model);
HostDeviceVector<bst_float>& predictions = e.predictions; DevicePredictInternal(m, out, model, 0, model.trees.size());
} else if (model.learner_model_param_->num_output_group == 1 &&
if (predictions.Size() == 0) { updaters->size() > 0 &&
this->InitOutPredictions(dmat->Info(), &predictions, model);
}
if (model.learner_model_param_->num_output_group == 1 && updaters->size() > 0 &&
num_new_trees == 1 && num_new_trees == 1 &&
updaters->back()->UpdatePredictionCache(e.data.get(), &predictions)) { updaters->back()->UpdatePredictionCache(m, out)) {
// do nothing {}
} else { } else {
DevicePredictInternal(dmat, &predictions, model, old_ntree, model.trees.size()); DevicePredictInternal(m, out, model, old_ntree, model.trees.size());
}
} }
auto delta = num_new_trees / model.learner_model_param_->num_output_group;
predts->Update(delta);
} }
void PredictInstance(const SparsePage::Inst& inst, void PredictInstance(const SparsePage::Inst& inst,
@ -442,11 +432,6 @@ class GPUPredictor : public xgboost::Predictor {
void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override { void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
Predictor::Configure(cfg); Predictor::Configure(cfg);
int device = generic_param_->gpu_id;
if (device >= 0) {
ConfigureDevice(device);
}
} }
private: private:
@ -469,9 +454,8 @@ class GPUPredictor : public xgboost::Predictor {
XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor") XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
.describe("Make predictions using GPU.") .describe("Make predictions using GPU.")
.set_body([](GenericParameter const* generic_param, .set_body([](GenericParameter const* generic_param) {
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache) { return new GPUPredictor(generic_param);
return new GPUPredictor(generic_param, cache);
}); });
} // namespace predictor } // namespace predictor

View File

@ -1,24 +1,60 @@
/*! /*!
* Copyright by Contributors 2017 * Copyright 2017-2020 by Contributors
*/ */
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <xgboost/predictor.h> #include <xgboost/predictor.h>
#include "xgboost/data.h"
#include "xgboost/generic_parameters.h"
namespace dmlc { namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg); DMLC_REGISTRY_ENABLE(::xgboost::PredictorReg);
} // namespace dmlc } // namespace dmlc
namespace xgboost { namespace xgboost {
void PredictionContainer::ClearExpiredEntries() {
std::vector<DMatrix*> expired;
for (auto& kv : container_) {
if (kv.second.ref.expired()) {
expired.emplace_back(kv.first);
}
}
for (auto const& ptr : expired) {
container_.erase(ptr);
}
}
PredictionCacheEntry &PredictionContainer::Cache(std::shared_ptr<DMatrix> m, int32_t device) {
this->ClearExpiredEntries();
container_[m.get()].ref = m;
if (device != GenericParameter::kCpuId) {
container_[m.get()].predictions.SetDevice(device);
}
return container_[m.get()];
}
PredictionCacheEntry &PredictionContainer::Entry(DMatrix *m) {
CHECK(container_.find(m) != container_.cend());
CHECK(container_.at(m).ref.lock())
<< "[Internal error]: DMatrix: " << m << " has expired.";
return container_.at(m);
}
decltype(PredictionContainer::container_) const& PredictionContainer::Container() {
this->ClearExpiredEntries();
return container_;
}
void Predictor::Configure( void Predictor::Configure(
const std::vector<std::pair<std::string, std::string>>& cfg) { const std::vector<std::pair<std::string, std::string>>& cfg) {
} }
Predictor* Predictor::Create( Predictor* Predictor::Create(
std::string const& name, GenericParameter const* generic_param, std::string const& name, GenericParameter const* generic_param) {
std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache) {
auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name); auto* e = ::dmlc::Registry<PredictorReg>::Get()->Find(name);
if (e == nullptr) { if (e == nullptr) {
LOG(FATAL) << "Unknown predictor type " << name; LOG(FATAL) << "Unknown predictor type " << name;
} }
auto p_predictor = (e->body)(generic_param, cache); auto p_predictor = (e->body)(generic_param);
return p_predictor; return p_predictor;
} }
} // namespace xgboost } // namespace xgboost

View File

@ -10,6 +10,7 @@
#include "xgboost/learner.h" #include "xgboost/learner.h"
#include "../helpers.h" #include "../helpers.h"
#include "../../../src/gbm/gbtree.h" #include "../../../src/gbm/gbtree.h"
#include "xgboost/predictor.h"
namespace xgboost { namespace xgboost {
TEST(GBTree, SelectTreeMethod) { TEST(GBTree, SelectTreeMethod) {
@ -22,9 +23,8 @@ TEST(GBTree, SelectTreeMethod) {
mparam.num_feature = kCols; mparam.num_feature = kCols;
mparam.num_output_group = 1; mparam.num_output_group = 1;
std::vector<std::shared_ptr<DMatrix> > caches;
std::unique_ptr<GradientBooster> p_gbm { std::unique_ptr<GradientBooster> p_gbm {
GradientBooster::Create("gbtree", &generic_param, &mparam, caches)}; GradientBooster::Create("gbtree", &generic_param, &mparam)};
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm); auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
// Test if `tree_method` can be set // Test if `tree_method` can be set

View File

@ -1,8 +1,11 @@
/*! /*!
* Copyright 2016-2019 XGBoost contributors * Copyright 2016-2020 XGBoost contributors
*/ */
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <xgboost/metric.h>
#include <xgboost/learner.h>
#include <xgboost/gbm.h> #include <xgboost/gbm.h>
#include <xgboost/json.h> #include <xgboost/json.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
@ -16,6 +19,7 @@
#include "../../src/data/simple_csr_source.h" #include "../../src/data/simple_csr_source.h"
#include "../../src/gbm/gbtree_model.h" #include "../../src/gbm/gbtree_model.h"
#include "xgboost/predictor.h"
bool FileExists(const std::string& filename) { bool FileExists(const std::string& filename) {
struct stat st; struct stat st;
@ -265,13 +269,19 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
} }
} }
gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param) { gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes) {
gbm::GBTreeModel model(param);
for (size_t i = 0; i < n_classes; ++i) {
std::vector<std::unique_ptr<RegTree>> trees; std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree)); trees.push_back(std::unique_ptr<RegTree>(new RegTree));
if (i == 0) {
(*trees.back())[0].SetLeaf(1.5f); (*trees.back())[0].SetLeaf(1.5f);
(*trees.back()).Stat(0).sum_hess = 1.0f; (*trees.back()).Stat(0).sum_hess = 1.0f;
gbm::GBTreeModel model(param); }
model.CommitModel(std::move(trees), 0); model.CommitModel(std::move(trees), i);
}
return model; return model;
} }
@ -279,8 +289,9 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
std::string name, Args kwargs, size_t kRows, size_t kCols, std::string name, Args kwargs, size_t kRows, size_t kCols,
LearnerModelParam const* learner_model_param, LearnerModelParam const* learner_model_param,
GenericParameter const* generic_param) { GenericParameter const* generic_param) {
auto caches = std::make_shared< PredictionContainer >();;
std::unique_ptr<GradientBooster> gbm { std::unique_ptr<GradientBooster> gbm {
GradientBooster::Create(name, generic_param, learner_model_param, {})}; GradientBooster::Create(name, generic_param, learner_model_param)};
gbm->Configure(kwargs); gbm->Configure(kwargs);
auto pp_dmat = CreateDMatrix(kRows, kCols, 0); auto pp_dmat = CreateDMatrix(kRows, kCols, 0);
auto p_dmat = *pp_dmat; auto p_dmat = *pp_dmat;
@ -297,7 +308,9 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
h_gpair[i] = {static_cast<float>(i), 1}; h_gpair[i] = {static_cast<float>(i), 1};
} }
gbm->DoBoost(p_dmat.get(), &gpair, nullptr); PredictionCacheEntry predts;
gbm->DoBoost(p_dmat.get(), &gpair, &predts);
delete pp_dmat; delete pp_dmat;
return gbm; return gbm;

View File

@ -16,16 +16,13 @@
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/objective.h>
#include <xgboost/metric.h>
#include <xgboost/json.h> #include <xgboost/json.h>
#include <xgboost/predictor.h>
#include <xgboost/generic_parameters.h> #include <xgboost/generic_parameters.h>
#include <xgboost/c_api.h> #include <xgboost/c_api.h>
#include <xgboost/learner.h>
#include "../../src/common/common.h" #include "../../src/common/common.h"
#include "../../src/common/hist_util.h" #include "../../src/common/hist_util.h"
#include "../../src/gbm/gbtree_model.h"
#if defined(__CUDACC__) #if defined(__CUDACC__)
#include "../../src/data/ellpack_page.cuh" #include "../../src/data/ellpack_page.cuh"
#endif #endif
@ -42,6 +39,12 @@
#define GPUIDX -1 #define GPUIDX -1
#endif #endif
namespace xgboost {
class ObjFunction;
class Metric;
struct LearnerModelParam;
}
bool FileExists(const std::string& filename); bool FileExists(const std::string& filename);
int64_t GetFileSize(const std::string& filename); int64_t GetFileSize(const std::string& filename);
@ -206,7 +209,7 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
size_t n_rows, size_t n_cols, size_t page_size, bool deterministic, size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory()); const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory());
gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param); gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes = 1);
std::unique_ptr<GradientBooster> CreateTrainedGBM( std::unique_ptr<GradientBooster> CreateTrainedGBM(
std::string name, Args kwargs, size_t kRows, size_t kCols, std::string name, Args kwargs, size_t kRows, size_t kCols,

View File

@ -1,4 +1,6 @@
// Copyright by Contributors /*!
* Copyright 2017-2020 XGBoost contributors
*/
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/predictor.h> #include <xgboost/predictor.h>
@ -9,9 +11,8 @@
namespace xgboost { namespace xgboost {
TEST(CpuPredictor, Basic) { TEST(CpuPredictor, Basic) {
auto lparam = CreateEmptyGenericParam(GPUIDX); auto lparam = CreateEmptyGenericParam(GPUIDX);
auto cache = std::make_shared<std::unordered_map<DMatrix*, PredictionCacheEntry>>();
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam, cache)); std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
int kRows = 5; int kRows = 5;
int kCols = 5; int kCols = 5;
@ -26,10 +27,11 @@ TEST(CpuPredictor, Basic) {
auto dmat = CreateDMatrix(kRows, kCols, 0); auto dmat = CreateDMatrix(kRows, kCols, 0);
// Test predict batch // Test predict batch
HostDeviceVector<float> out_predictions; PredictionCacheEntry out_predictions;
cpu_predictor->PredictBatch((*dmat).get(), &out_predictions, model, 0); cpu_predictor->PredictBatch((*dmat).get(), &out_predictions, model, 0);
std::vector<float>& out_predictions_h = out_predictions.HostVector(); ASSERT_EQ(model.trees.size(), out_predictions.version);
for (size_t i = 0; i < out_predictions.Size(); i++) { std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
for (size_t i = 0; i < out_predictions.predictions.Size(); i++) {
ASSERT_EQ(out_predictions_h[i], 1.5); ASSERT_EQ(out_predictions_h[i], 1.5);
} }
@ -81,10 +83,9 @@ TEST(CpuPredictor, ExternalMemory) {
std::string filename = tmpdir.path + "/big.libsvm"; std::string filename = tmpdir.path + "/big.libsvm";
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename); std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename);
auto lparam = CreateEmptyGenericParam(GPUIDX); auto lparam = CreateEmptyGenericParam(GPUIDX);
auto cache = std::make_shared<std::unordered_map<DMatrix*, PredictionCacheEntry>>();
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam, cache)); std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
LearnerModelParam param; LearnerModelParam param;
param.base_score = 0; param.base_score = 0;
@ -94,10 +95,10 @@ TEST(CpuPredictor, ExternalMemory) {
gbm::GBTreeModel model = CreateTestModel(&param); gbm::GBTreeModel model = CreateTestModel(&param);
// Test predict batch // Test predict batch
HostDeviceVector<float> out_predictions; PredictionCacheEntry out_predictions;
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
std::vector<float> &out_predictions_h = out_predictions.HostVector(); std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
ASSERT_EQ(out_predictions.Size(), dmat->Info().num_row_); ASSERT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_);
for (const auto& v : out_predictions_h) { for (const auto& v : out_predictions_h) {
ASSERT_EQ(v, 1.5); ASSERT_EQ(v, 1.5);
} }

View File

@ -1,6 +1,5 @@
/*! /*!
* Copyright 2017-2019 XGBoost contributors * Copyright 2017-2020 XGBoost contributors
*/ */
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <xgboost/c_api.h> #include <xgboost/c_api.h>
@ -19,12 +18,11 @@ namespace predictor {
TEST(GpuPredictor, Basic) { TEST(GpuPredictor, Basic) {
auto cpu_lparam = CreateEmptyGenericParam(-1); auto cpu_lparam = CreateEmptyGenericParam(-1);
auto gpu_lparam = CreateEmptyGenericParam(0); auto gpu_lparam = CreateEmptyGenericParam(0);
auto cache = std::make_shared<std::unordered_map<DMatrix*, PredictionCacheEntry>>();
std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam, cache)); std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam));
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &cpu_lparam, cache)); std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &cpu_lparam));
gpu_predictor->Configure({}); gpu_predictor->Configure({});
cpu_predictor->Configure({}); cpu_predictor->Configure({});
@ -41,16 +39,17 @@ TEST(GpuPredictor, Basic) {
gbm::GBTreeModel model = CreateTestModel(&param); gbm::GBTreeModel model = CreateTestModel(&param);
// Test predict batch // Test predict batch
HostDeviceVector<float> gpu_out_predictions; PredictionCacheEntry gpu_out_predictions;
HostDeviceVector<float> cpu_out_predictions; PredictionCacheEntry cpu_out_predictions;
gpu_predictor->PredictBatch((*dmat).get(), &gpu_out_predictions, model, 0); gpu_predictor->PredictBatch((*dmat).get(), &gpu_out_predictions, model, 0);
ASSERT_EQ(model.trees.size(), gpu_out_predictions.version);
cpu_predictor->PredictBatch((*dmat).get(), &cpu_out_predictions, model, 0); cpu_predictor->PredictBatch((*dmat).get(), &cpu_out_predictions, model, 0);
std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.HostVector(); std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.predictions.HostVector();
std::vector<float>& cpu_out_predictions_h = cpu_out_predictions.HostVector(); std::vector<float>& cpu_out_predictions_h = cpu_out_predictions.predictions.HostVector();
float abs_tolerance = 0.001; float abs_tolerance = 0.001;
for (int j = 0; j < gpu_out_predictions.Size(); j++) { for (int j = 0; j < gpu_out_predictions.predictions.Size(); j++) {
ASSERT_NEAR(gpu_out_predictions_h[j], cpu_out_predictions_h[j], abs_tolerance); ASSERT_NEAR(gpu_out_predictions_h[j], cpu_out_predictions_h[j], abs_tolerance);
} }
delete dmat; delete dmat;
@ -59,9 +58,8 @@ TEST(GpuPredictor, Basic) {
TEST(gpu_predictor, ExternalMemoryTest) { TEST(gpu_predictor, ExternalMemoryTest) {
auto lparam = CreateEmptyGenericParam(0); auto lparam = CreateEmptyGenericParam(0);
auto cache = std::make_shared<std::unordered_map<DMatrix*, PredictionCacheEntry>>();
std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam, cache)); std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
gpu_predictor->Configure({}); gpu_predictor->Configure({});
LearnerModelParam param; LearnerModelParam param;
@ -70,7 +68,7 @@ TEST(gpu_predictor, ExternalMemoryTest) {
param.num_output_group = n_classes; param.num_output_group = n_classes;
param.base_score = 0.5; param.base_score = 0.5;
gbm::GBTreeModel model = CreateTestModel(&param); gbm::GBTreeModel model = CreateTestModel(&param, n_classes);
std::vector<std::unique_ptr<DMatrix>> dmats; std::vector<std::unique_ptr<DMatrix>> dmats;
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string file0 = tmpdir.path + "/big_0.libsvm"; std::string file0 = tmpdir.path + "/big_0.libsvm";
@ -82,10 +80,10 @@ TEST(gpu_predictor, ExternalMemoryTest) {
for (const auto& dmat: dmats) { for (const auto& dmat: dmats) {
dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5); dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5);
HostDeviceVector<float> out_predictions; PredictionCacheEntry out_predictions;
gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_ * n_classes); EXPECT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_ * n_classes);
const std::vector<float> &host_vector = out_predictions.ConstHostVector(); const std::vector<float> &host_vector = out_predictions.predictions.ConstHostVector();
for (int i = 0; i < host_vector.size() / n_classes; i++) { for (int i = 0; i < host_vector.size() / n_classes; i++) {
ASSERT_EQ(host_vector[i * n_classes], 2.0); ASSERT_EQ(host_vector[i * n_classes], 2.0);
ASSERT_EQ(host_vector[i * n_classes + 1], 0.5); ASSERT_EQ(host_vector[i * n_classes + 1], 0.5);

View File

@ -0,0 +1,33 @@
/*!
* Copyright 2020 by Contributors
*/
#include <cstddef>
#include <gtest/gtest.h>
#include <xgboost/predictor.h>
#include <xgboost/data.h>
#include "../helpers.h"
#include "xgboost/generic_parameters.h"
namespace xgboost {
TEST(Predictor, PredictionCache) {
size_t constexpr kRows = 16, kCols = 4;
PredictionContainer container;
DMatrix* m;
// Add a cache that is immediately expired.
auto add_cache = [&]() {
auto *pp_dmat = CreateDMatrix(kRows, kCols, 0);
auto p_dmat = *pp_dmat;
container.Cache(p_dmat, GenericParameter::kCpuId);
m = p_dmat.get();
delete pp_dmat;
};
add_cache();
ASSERT_EQ(container.Container().size(), 0);
add_cache();
EXPECT_ANY_THROW(container.Entry(m));
}
} // namespace xgboost

View File

@ -19,11 +19,12 @@ rng = np.random.RandomState(1994)
@contextmanager @contextmanager
def captured_output(): def captured_output():
""" """Reassign stdout temporarily in order to test printed statements
Reassign stdout temporarily in order to test printed statements Taken from:
Taken from: https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python https://stackoverflow.com/questions/4219717/how-to-assert-output-with-nosetest-unittest-in-python
Also works for pytest. Also works for pytest.
""" """
new_out, new_err = StringIO(), StringIO() new_out, new_err = StringIO(), StringIO()
old_out, old_err = sys.stdout, sys.stderr old_out, old_err = sys.stdout, sys.stderr
@ -42,10 +43,17 @@ class TestBasic(unittest.TestCase):
param = {'max_depth': 2, 'eta': 1, param = {'max_depth': 2, 'eta': 1,
'objective': 'binary:logistic'} 'objective': 'binary:logistic'}
# specify validations set to watch performance # specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')] watchlist = [(dtrain, 'train')]
num_round = 2 num_round = 2
bst = xgb.train(param, dtrain, num_round, watchlist) bst = xgb.train(param, dtrain, num_round, watchlist, verbose_eval=True)
# this is prediction
preds = bst.predict(dtrain)
labels = dtrain.get_label()
err = sum(1 for i in range(len(preds))
if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
# error must be smaller than 10%
assert err < 0.1
preds = bst.predict(dtest) preds = bst.predict(dtest)
labels = dtest.get_label() labels = dtest.get_label()
err = sum(1 for i in range(len(preds)) err = sum(1 for i in range(len(preds))