[CORE] Refactor cache mechanism (#1540)

This commit is contained in:
Tianqi Chen 2016-09-02 20:39:07 -07:00 committed by GitHub
parent 6dabdd33e3
commit ecec5f7959
9 changed files with 320 additions and 421 deletions

View File

@ -3,6 +3,10 @@ XGBoost Change Log
This file records the changes in xgboost library in reverse chronological order.
## in progress version
* Refactored gbm to allow more friendly cache strategy
- Specialized some prediction routine
## v0.6 (2016.07.29)
* Version 0.5 is skipped due to major improvements in the core
* Major refactor of core library.

View File

@ -13,8 +13,10 @@
#include <utility>
#include <string>
#include <functional>
#include <memory>
#include "./base.h"
#include "./data.h"
#include "./objective.h"
#include "./feature_map.h"
namespace xgboost {
@ -50,13 +52,6 @@ class GradientBooster {
* \param fo output stream
*/
virtual void Save(dmlc::Stream* fo) const = 0;
/*!
* \brief reset the predict buffer size.
* This will invalidate all the previous cached results
* and recalculate from scratch
* \param num_pbuffer The size of predict buffer.
*/
virtual void ResetPredBuffer(size_t num_pbuffer) {}
/*!
* \brief whether the model allow lazy checkpoint
* return true if model is only updated in DoBoost
@ -68,27 +63,21 @@ class GradientBooster {
/*!
* \brief perform update to the model(boosting)
* \param p_fmat feature matrix that provide access to features
* \param buffer_offset buffer index offset of these instances, if equals -1
* this means we do not have buffer index allocated to the gbm
* \param in_gpair address of the gradient pair statistics of the data
* \param obj The objective function, optional, can be nullptr when use customized version
* the booster may change content of gpair
*/
virtual void DoBoost(DMatrix* p_fmat,
int64_t buffer_offset,
std::vector<bst_gpair>* in_gpair) = 0;
std::vector<bst_gpair>* in_gpair,
ObjFunction* obj = nullptr) = 0;
/*!
* \brief generate predictions for given feature matrix
* \param dmat feature matrix
* \param buffer_offset buffer index offset of these instances, if equals -1
* this means we do not have buffer index allocated to the gbm
* a buffer index is assigned to each instance that requires repeative prediction
* the size of buffer is set by convention using GradientBooster.ResetPredBuffer(size);
* \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
*/
virtual void Predict(DMatrix* dmat,
int64_t buffer_offset,
std::vector<float>* out_preds,
unsigned ntree_limit = 0) = 0;
/*!
@ -128,9 +117,14 @@ class GradientBooster {
/*!
* \brief create a gradient booster from given name
* \param name name of gradient booster
* \param cache_mats The cache data matrix of the Booster.
* \param base_margin The base margin of prediction.
* \return The created booster.
*/
static GradientBooster* Create(const std::string& name);
static GradientBooster* Create(
const std::string& name,
const std::vector<std::shared_ptr<DMatrix> >& cache_mats,
float base_margin);
};
// implementing configure.
@ -144,8 +138,10 @@ inline void GradientBooster::Configure(PairIter begin, PairIter end) {
* \brief Registry entry for tree updater.
*/
struct GradientBoosterReg
: public dmlc::FunctionRegEntryBase<GradientBoosterReg,
std::function<GradientBooster* ()> > {
: public dmlc::FunctionRegEntryBase<
GradientBoosterReg,
std::function<GradientBooster* (const std::vector<std::shared_ptr<DMatrix> > &cached_mats,
float base_margin)> > {
};
/*!

View File

@ -166,7 +166,7 @@ class Learner : public rabit::Serializable {
* \param cache_data The matrix to cache the prediction.
* \return Created learner.
*/
static Learner* Create(const std::vector<DMatrix*>& cache_data);
static Learner* Create(const std::vector<std::shared_ptr<DMatrix> >& cache_data);
protected:
/*! \brief internal base score of the model */

View File

@ -22,7 +22,7 @@ namespace xgboost {
// booster wrapper for backward compatible reason.
class Booster {
public:
explicit Booster(const std::vector<DMatrix*>& cache_mats)
explicit Booster(const std::vector<std::shared_ptr<DMatrix> >& cache_mats)
: configured_(false),
initialized_(false),
learner_(Learner::Create(cache_mats)) {}
@ -207,8 +207,7 @@ int XGDMatrixCreateFromFile(const char *fname,
LOG(CONSOLE) << "XGBoost distributed mode detected, "
<< "will split data among workers";
}
*out = DMatrix::Load(
fname, false, true);
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, false, true));
API_END();
}
@ -224,7 +223,7 @@ int XGDMatrixCreateFromDataIter(
scache = cache_info;
}
NativeDataIter parser(data_handle, callback);
*out = DMatrix::Create(&parser, scache);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&parser, scache));
API_END();
}
@ -250,16 +249,16 @@ XGB_DLL int XGDMatrixCreateFromCSR(const xgboost::bst_ulong* indptr,
}
mat.info.num_row = nindptr - 1;
mat.info.num_nonzero = static_cast<uint64_t>(nelem);
*out = DMatrix::Create(std::move(source));
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
XGB_DLL int XGDMatrixCreateFromCSC(const xgboost::bst_ulong* col_ptr,
const unsigned* indices,
const float* data,
xgboost::bst_ulong nindptr,
xgboost::bst_ulong nelem,
DMatrixHandle* out) {
const unsigned* indices,
const float* data,
xgboost::bst_ulong nindptr,
xgboost::bst_ulong nelem,
DMatrixHandle* out) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
API_BEGIN();
@ -292,15 +291,15 @@ XGB_DLL int XGDMatrixCreateFromCSC(const xgboost::bst_ulong* col_ptr,
mat.info.num_row = mat.row_ptr_.size() - 1;
mat.info.num_col = static_cast<uint64_t>(ncol);
mat.info.num_nonzero = nelem;
*out = DMatrix::Create(std::move(source));
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
XGB_DLL int XGDMatrixCreateFromMat(const float* data,
xgboost::bst_ulong nrow,
xgboost::bst_ulong ncol,
float missing,
DMatrixHandle* out) {
xgboost::bst_ulong nrow,
xgboost::bst_ulong ncol,
float missing,
DMatrixHandle* out) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
API_BEGIN();
@ -324,19 +323,19 @@ XGB_DLL int XGDMatrixCreateFromMat(const float* data,
mat.row_ptr_.push_back(mat.row_ptr_.back() + nelem);
}
mat.info.num_nonzero = mat.row_data_.size();
*out = DMatrix::Create(std::move(source));
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
const int* idxset,
xgboost::bst_ulong len,
DMatrixHandle* out) {
const int* idxset,
xgboost::bst_ulong len,
DMatrixHandle* out) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
API_BEGIN();
data::SimpleCSRSource src;
src.CopyFrom(static_cast<DMatrix*>(handle));
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
data::SimpleCSRSource& ret = *source;
CHECK_EQ(src.info.group_ptr.size(), 0)
@ -371,21 +370,21 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
ret.info.root_index.push_back(src.info.root_index[ridx]);
}
}
*out = DMatrix::Create(std::move(source));
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
XGB_DLL int XGDMatrixFree(DMatrixHandle handle) {
API_BEGIN();
delete static_cast<DMatrix*>(handle);
delete static_cast<std::shared_ptr<DMatrix>*>(handle);
API_END();
}
XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
const char* fname,
int silent) {
const char* fname,
int silent) {
API_BEGIN();
static_cast<DMatrix*>(handle)->SaveToLocalFile(fname);
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->SaveToLocalFile(fname);
API_END();
}
@ -394,7 +393,8 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const float* info,
xgboost::bst_ulong len) {
API_BEGIN();
static_cast<DMatrix*>(handle)->info().SetInfo(field, info, kFloat32, len);
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->info().SetInfo(field, info, kFloat32, len);
API_END();
}
@ -403,16 +403,17 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const unsigned* info,
xgboost::bst_ulong len) {
API_BEGIN();
static_cast<DMatrix*>(handle)->info().SetInfo(field, info, kUInt32, len);
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->info().SetInfo(field, info, kUInt32, len);
API_END();
}
XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned* group,
xgboost::bst_ulong len) {
const unsigned* group,
xgboost::bst_ulong len) {
API_BEGIN();
DMatrix *pmat = static_cast<DMatrix*>(handle);
MetaInfo& info = pmat->info();
std::shared_ptr<DMatrix> *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
MetaInfo& info = pmat->get()->info();
info.group_ptr.resize(len + 1);
info.group_ptr[0] = 0;
for (uint64_t i = 0; i < len; ++i) {
@ -422,11 +423,11 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
}
XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
const char* field,
xgboost::bst_ulong* out_len,
const float** out_dptr) {
const char* field,
xgboost::bst_ulong* out_len,
const float** out_dptr) {
API_BEGIN();
const MetaInfo& info = static_cast<const DMatrix*>(handle)->info();
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->info();
const std::vector<float>* vec = nullptr;
if (!std::strcmp(field, "label")) {
vec = &info.labels;
@ -443,11 +444,11 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
}
XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
const char *field,
xgboost::bst_ulong *out_len,
const unsigned **out_dptr) {
const char *field,
xgboost::bst_ulong *out_len,
const unsigned **out_dptr) {
API_BEGIN();
const MetaInfo& info = static_cast<const DMatrix*>(handle)->info();
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->info();
const std::vector<unsigned>* vec = nullptr;
if (!std::strcmp(field, "root_index")) {
vec = &info.root_index;
@ -460,16 +461,18 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
}
XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
xgboost::bst_ulong *out) {
xgboost::bst_ulong *out) {
API_BEGIN();
*out = static_cast<xgboost::bst_ulong>(static_cast<const DMatrix*>(handle)->info().num_row);
*out = static_cast<xgboost::bst_ulong>(
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->info().num_row);
API_END();
}
XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
xgboost::bst_ulong *out) {
xgboost::bst_ulong *out) {
API_BEGIN();
*out = static_cast<size_t>(static_cast<const DMatrix*>(handle)->info().num_col);
*out = static_cast<size_t>(
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->info().num_col);
API_END();
}
@ -478,9 +481,9 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
xgboost::bst_ulong len,
BoosterHandle *out) {
API_BEGIN();
std::vector<DMatrix*> mats;
std::vector<std::shared_ptr<DMatrix> > mats;
for (xgboost::bst_ulong i = 0; i < len; ++i) {
mats.push_back(static_cast<DMatrix*>(dmats[i]));
mats.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
}
*out = new Booster(mats);
API_END();
@ -493,50 +496,52 @@ XGB_DLL int XGBoosterFree(BoosterHandle handle) {
}
XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
const char *name,
const char *value) {
const char *name,
const char *value) {
API_BEGIN();
static_cast<Booster*>(handle)->SetParam(name, value);
API_END();
}
XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
int iter,
DMatrixHandle dtrain) {
int iter,
DMatrixHandle dtrain) {
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
DMatrix *dtr = static_cast<DMatrix*>(dtrain);
std::shared_ptr<DMatrix> *dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
bst->LazyInit();
bst->learner()->UpdateOneIter(iter, dtr);
bst->learner()->UpdateOneIter(iter, dtr->get());
API_END();
}
XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
DMatrixHandle dtrain,
float *grad,
float *hess,
xgboost::bst_ulong len) {
DMatrixHandle dtrain,
float *grad,
float *hess,
xgboost::bst_ulong len) {
std::vector<bst_gpair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
DMatrix* dtr = static_cast<DMatrix*>(dtrain);
std::shared_ptr<DMatrix>* dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
tmp_gpair.resize(len);
for (xgboost::bst_ulong i = 0; i < len; ++i) {
tmp_gpair[i] = bst_gpair(grad[i], hess[i]);
}
bst->LazyInit();
bst->learner()->BoostOneIter(0, dtr, &tmp_gpair);
bst->learner()->BoostOneIter(0, dtr->get(), &tmp_gpair);
API_END();
}
XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
int iter,
DMatrixHandle dmats[],
const char* evnames[],
xgboost::bst_ulong len,
const char** out_str) {
int iter,
DMatrixHandle dmats[],
const char* evnames[],
xgboost::bst_ulong len,
const char** out_str) {
std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str;
API_BEGIN();
Booster* bst = static_cast<Booster*>(handle);
@ -544,7 +549,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
std::vector<std::string> data_names;
for (xgboost::bst_ulong i = 0; i < len; ++i) {
data_sets.push_back(static_cast<DMatrix*>(dmats[i]));
data_sets.push_back(static_cast<std::shared_ptr<DMatrix>*>(dmats[i])->get());
data_names.push_back(std::string(evnames[i]));
}
@ -555,17 +560,17 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
}
XGB_DLL int XGBoosterPredict(BoosterHandle handle,
DMatrixHandle dmat,
int option_mask,
unsigned ntree_limit,
xgboost::bst_ulong *len,
const float **out_result) {
DMatrixHandle dmat,
int option_mask,
unsigned ntree_limit,
xgboost::bst_ulong *len,
const float **out_result) {
std::vector<float>& preds = XGBAPIThreadLocalStore::Get()->ret_vec_float;
API_BEGIN();
Booster *bst = static_cast<Booster*>(handle);
bst->LazyInit();
bst->learner()->Predict(
static_cast<DMatrix*>(dmat),
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
(option_mask & 1) != 0,
&preds, ntree_limit,
(option_mask & 2) != 0);

View File

@ -156,16 +156,18 @@ void CLITrain(const CLIParam& param) {
LOG(CONSOLE) << "start " << pname << ":" << rabit::GetRank();
}
// load in data.
std::unique_ptr<DMatrix> dtrain(
std::shared_ptr<DMatrix> dtrain(
DMatrix::Load(param.train_path, param.silent != 0, param.dsplit == 2));
std::vector<std::unique_ptr<DMatrix> > deval;
std::vector<DMatrix*> cache_mats, eval_datasets;
cache_mats.push_back(dtrain.get());
std::vector<std::shared_ptr<DMatrix> > deval;
std::vector<std::shared_ptr<DMatrix> > cache_mats;
std::vector<DMatrix*> eval_datasets;
cache_mats.push_back(dtrain);
for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
deval.emplace_back(
DMatrix::Load(param.eval_data_paths[i], param.silent != 0, param.dsplit == 2));
std::shared_ptr<DMatrix>(DMatrix::Load(param.eval_data_paths[i],
param.silent != 0, param.dsplit == 2)));
eval_datasets.push_back(deval.back().get());
cache_mats.push_back(deval.back().get());
cache_mats.push_back(deval.back());
}
std::vector<std::string> eval_data_names = param.eval_data_names;
if (param.eval_train) {

View File

@ -87,6 +87,9 @@ struct GBLinearTrainParam : public dmlc::Parameter<GBLinearTrainParam> {
*/
class GBLinear : public GradientBooster {
public:
explicit GBLinear(float base_margin)
: base_margin_(base_margin) {
}
void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
if (model.weight.size() == 0) {
model.param.InitAllowUnknown(cfg);
@ -99,9 +102,9 @@ class GBLinear : public GradientBooster {
void Save(dmlc::Stream* fo) const override {
model.Save(fo);
}
virtual void DoBoost(DMatrix *p_fmat,
int64_t buffer_offset,
std::vector<bst_gpair> *in_gpair) {
void DoBoost(DMatrix *p_fmat,
std::vector<bst_gpair> *in_gpair,
ObjFunction* obj) override {
// lazily initialize the model when not ready.
if (model.weight.size() == 0) {
model.InitModel();
@ -168,7 +171,6 @@ class GBLinear : public GradientBooster {
}
void Predict(DMatrix *p_fmat,
int64_t buffer_offset,
std::vector<float> *out_preds,
unsigned ntree_limit) override {
if (model.weight.size() == 0) {
@ -177,6 +179,11 @@ class GBLinear : public GradientBooster {
CHECK_EQ(ntree_limit, 0)
<< "GBLinear::Predict ntrees is only valid for gbtree predictor";
std::vector<float> &preds = *out_preds;
const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
if (base_margin.size() != 0) {
CHECK_EQ(preds.size(), base_margin.size())
<< "base_margin.size does not match with prediction size";
}
preds.resize(0);
// start collecting the prediction
dmlc::DataIter<RowBatch> *iter = p_fmat->RowIterator();
@ -188,24 +195,27 @@ class GBLinear : public GradientBooster {
// k is number of group
preds.resize(preds.size() + batch.size * ngroup);
// parallel over local batch
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
const omp_ulong nsize = static_cast<omp_ulong>(batch.size);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
for (omp_ulong i = 0; i < nsize; ++i) {
const size_t ridx = batch.base_rowid + i;
// loop over output groups
for (int gid = 0; gid < ngroup; ++gid) {
this->Pred(batch[i], &preds[ridx * ngroup], gid);
float margin = (base_margin.size() != 0) ?
base_margin[ridx * ngroup + gid] : base_margin_;
this->Pred(batch[i], &preds[ridx * ngroup], gid, margin);
}
}
}
}
// add base margin
void Predict(const SparseBatch::Inst &inst,
std::vector<float> *out_preds,
unsigned ntree_limit,
unsigned root_index) override {
const int ngroup = model.param.num_output_group;
for (int gid = 0; gid < ngroup; ++gid) {
this->Pred(inst, dmlc::BeginPtr(*out_preds), gid);
this->Pred(inst, dmlc::BeginPtr(*out_preds), gid, base_margin_);
}
}
void PredictLeaf(DMatrix *p_fmat,
@ -232,8 +242,8 @@ class GBLinear : public GradientBooster {
}
protected:
inline void Pred(const RowBatch::Inst &inst, float *preds, int gid) {
float psum = model.bias()[gid];
inline void Pred(const RowBatch::Inst &inst, float *preds, int gid, float base) {
float psum = model.bias()[gid] + base;
for (bst_uint i = 0; i < inst.length; ++i) {
if (inst[i].index >= model.param.num_feature) continue;
psum += inst[i].fvalue * model[inst[i].index][gid];
@ -278,6 +288,8 @@ class GBLinear : public GradientBooster {
return &weight[i * param.num_output_group];
}
};
// biase margin score
float base_margin_;
// model field
Model model;
// training parameter
@ -292,8 +304,8 @@ DMLC_REGISTER_PARAMETER(GBLinearTrainParam);
XGBOOST_REGISTER_GBM(GBLinear, "gblinear")
.describe("Linear booster, implement generalized linear model.")
.set_body([]() {
return new GBLinear();
.set_body([](const std::vector<std::shared_ptr<DMatrix> >&cache, float base_margin) {
return new GBLinear(base_margin);
});
} // namespace gbm
} // namespace xgboost

View File

@ -11,12 +11,15 @@ DMLC_REGISTRY_ENABLE(::xgboost::GradientBoosterReg);
} // namespace dmlc
namespace xgboost {
GradientBooster* GradientBooster::Create(const std::string& name) {
GradientBooster* GradientBooster::Create(
const std::string& name,
const std::vector<std::shared_ptr<DMatrix> >& cache_mats,
float base_margin) {
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown gbm type " << name;
}
return (e->body)();
return (e->body)(cache_mats, base_margin);
}
} // namespace xgboost

View File

@ -15,6 +15,7 @@
#include <utility>
#include <string>
#include <limits>
#include <algorithm>
#include "../common/common.h"
#include "../common/random.h"
@ -123,10 +124,24 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
}
};
// cache entry
struct CacheEntry {
std::shared_ptr<DMatrix> data;
std::vector<float> predictions;
};
// gradient boosted trees
class GBTree : public GradientBooster {
public:
GBTree() : num_pbuffer(0) {}
explicit GBTree(float base_margin) : base_margin_(base_margin) {}
void InitCache(const std::vector<std::shared_ptr<DMatrix> > &cache) {
for (const std::shared_ptr<DMatrix>& d : cache) {
CacheEntry e;
e.data = d;
cache_[d.get()] = std::move(e);
}
}
void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
this->cfg = cfg;
@ -160,8 +175,6 @@ class GBTree : public GradientBooster {
this->cfg.clear();
this->cfg.push_back(std::make_pair(std::string("num_feature"),
common::ToString(mparam.num_feature)));
// clear the predict buffer.
this->ResetPredBuffer(num_pbuffer);
}
void Save(dmlc::Stream* fo) const override {
@ -175,27 +188,19 @@ class GBTree : public GradientBooster {
}
}
void ResetPredBuffer(size_t num_pbuffer) override {
this->num_pbuffer = num_pbuffer;
pred_buffer.clear();
pred_counter.clear();
pred_buffer.resize(this->PredBufferSize(), 0.0f);
pred_counter.resize(this->PredBufferSize(), 0);
}
bool AllowLazyCheckPoint() const override {
return mparam.num_output_group == 1 ||
tparam.updater_seq.find("distcol") != std::string::npos;
}
void DoBoost(DMatrix* p_fmat,
int64_t buffer_offset,
std::vector<bst_gpair>* in_gpair) override {
std::vector<bst_gpair>* in_gpair,
ObjFunction* obj) override {
const std::vector<bst_gpair>& gpair = *in_gpair;
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
if (mparam.num_output_group == 1) {
std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(gpair, p_fmat, buffer_offset, 0, &ret);
BoostNewTrees(gpair, p_fmat, 0, &ret);
new_trees.push_back(std::move(ret));
} else {
const int ngroup = mparam.num_output_group;
@ -209,7 +214,7 @@ class GBTree : public GradientBooster {
tmp[i] = gpair[i * ngroup + gid];
}
std::vector<std::unique_ptr<RegTree> > ret;
BoostNewTrees(tmp, p_fmat, buffer_offset, gid, &ret);
BoostNewTrees(tmp, p_fmat, gid, &ret);
new_trees.push_back(std::move(ret));
}
}
@ -219,48 +224,21 @@ class GBTree : public GradientBooster {
}
void Predict(DMatrix* p_fmat,
int64_t buffer_offset,
std::vector<float>* out_preds,
unsigned ntree_limit) override {
const MetaInfo& info = p_fmat->info();
int nthread;
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
InitThreadTemp(nthread);
std::vector<float> &preds = *out_preds;
const size_t stride = p_fmat->info().num_row * mparam.num_output_group;
preds.resize(stride * (mparam.size_leaf_vector+1));
// start collecting the prediction
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
// parallel over local batch
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
int ridx_error = 0;
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
const int tid = omp_get_thread_num();
RegTree::FVec &feats = thread_temp[tid];
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
if (static_cast<size_t>(ridx) >= info.num_row) {
ridx_error = 1;
continue;
}
// loop over output groups
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
this->Pred(batch[i],
buffer_offset < 0 ? -1 : buffer_offset + ridx,
gid, info.GetRoot(ridx), &feats,
&preds[ridx * mparam.num_output_group + gid], stride,
ntree_limit);
if (ntree_limit == 0 ||
ntree_limit * mparam.num_output_group >= trees.size()) {
auto it = cache_.find(p_fmat);
if (it != cache_.end()) {
std::vector<float>& y = it->second.predictions;
if (y.size() != 0) {
out_preds->resize(y.size());
std::copy(y.begin(), y.end(), out_preds->begin());
return;
}
}
CHECK(!ridx_error) << "ridx out of bounds";
}
PredLoopInternal<GBTree>(p_fmat, out_preds, 0, ntree_limit, true);
}
void Predict(const SparseBatch::Inst& inst,
@ -271,12 +249,16 @@ class GBTree : public GradientBooster {
thread_temp.resize(1, RegTree::FVec());
thread_temp[0].Init(mparam.num_feature);
}
ntree_limit *= mparam.num_output_group;
if (ntree_limit == 0 || ntree_limit > trees.size()) {
ntree_limit = static_cast<unsigned>(trees.size());
}
out_preds->resize(mparam.num_output_group * (mparam.size_leaf_vector+1));
// loop over output groups
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
this->Pred(inst, -1, gid, root_index, &thread_temp[0],
&(*out_preds)[gid], mparam.num_output_group,
ntree_limit);
(*out_preds)[gid] =
PredValue(inst, gid, root_index,
&thread_temp[0], 0, ntree_limit) + base_margin_;
}
}
@ -301,6 +283,84 @@ class GBTree : public GradientBooster {
}
protected:
// internal prediction loop
// add predictions to out_preds
template<typename Derived>
inline void PredLoopInternal(
DMatrix* p_fmat,
std::vector<float>* out_preds,
unsigned tree_begin,
unsigned ntree_limit,
bool init_out_preds) {
int num_group = mparam.num_output_group;
ntree_limit *= num_group;
if (ntree_limit == 0 || ntree_limit > trees.size()) {
ntree_limit = static_cast<unsigned>(trees.size());
}
if (init_out_preds) {
size_t n = num_group * p_fmat->info().num_row;
const std::vector<float>& base_margin = p_fmat->info().base_margin;
out_preds->resize(n);
if (base_margin.size() != 0) {
CHECK_EQ(out_preds->size(), n);
std::copy(base_margin.begin(), base_margin.end(), out_preds->begin());
} else {
std::fill(out_preds->begin(), out_preds->end(), base_margin_);
}
}
if (num_group == 1) {
PredLoopSpecalize<Derived>(p_fmat, out_preds, 1,
tree_begin, ntree_limit);
} else {
PredLoopSpecalize<Derived>(p_fmat, out_preds, num_group,
tree_begin, ntree_limit);
}
}
template<typename Derived>
inline void PredLoopSpecalize(
DMatrix* p_fmat,
std::vector<float>* out_preds,
int num_group,
unsigned tree_begin,
unsigned tree_end) {
const MetaInfo& info = p_fmat->info();
int nthread;
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
CHECK_EQ(num_group, mparam.num_output_group);
InitThreadTemp(nthread);
std::vector<float> &preds = *out_preds;
CHECK_EQ(mparam.size_leaf_vector, 0)
<< "size_leaf_vector is enforced to 0 so far";
CHECK_EQ(preds.size(), p_fmat->info().num_row * num_group);
// start collecting the prediction
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
Derived* self = static_cast<Derived*>(this);
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
// parallel over local batch
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
const int tid = omp_get_thread_num();
RegTree::FVec &feats = thread_temp[tid];
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
CHECK_LT(static_cast<size_t>(ridx), info.num_row);
for (int gid = 0; gid < num_group; ++gid) {
size_t offset = ridx * num_group + gid;
preds[offset] +=
self->PredValue(batch[i], gid, info.GetRoot(ridx),
&feats, tree_begin, tree_end);
}
}
}
}
// initialize updater before using them
inline void InitUpdater() {
if (updaters.size() != 0) return;
@ -316,7 +376,6 @@ class GBTree : public GradientBooster {
inline void
BoostNewTrees(const std::vector<bst_gpair> &gpair,
DMatrix *p_fmat,
int64_t buffer_offset,
int bst_group,
std::vector<std::unique_ptr<RegTree> >* ret) {
this->InitUpdater();
@ -334,111 +393,50 @@ class GBTree : public GradientBooster {
for (auto& up : updaters) {
up->Update(gpair, p_fmat, new_trees);
}
// optimization, update buffer, if possible
// this is only under distributed column mode
// for safety check of lazy checkpoint
if (buffer_offset >= 0 &&
new_trees.size() == 1 && updaters.size() > 0 &&
updaters.back()->GetLeafPosition() != nullptr) {
CHECK_EQ(p_fmat->info().num_row, p_fmat->buffered_rowset().size());
this->UpdateBufferByPosition(p_fmat,
buffer_offset,
bst_group,
*new_trees[0],
updaters.back()->GetLeafPosition());
}
}
// commit new trees all at once
virtual void
CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
int bst_group) {
size_t old_ntree = trees.size();
for (size_t i = 0; i < new_trees.size(); ++i) {
trees.push_back(std::move(new_trees[i]));
tree_info.push_back(bst_group);
}
mparam.num_trees += static_cast<int>(new_trees.size());
}
// update buffer by pre-cached position
inline void UpdateBufferByPosition(DMatrix *p_fmat,
int64_t buffer_offset,
int bst_group,
const RegTree &new_tree,
const int* leaf_position) {
const RowSet& rowset = p_fmat->buffered_rowset();
const bst_omp_uint ndata = static_cast<bst_omp_uint>(rowset.size());
int pred_counter_error = 0, tid_error = 0;
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i];
const int64_t bid = this->BufferOffset(buffer_offset + ridx, bst_group);
const int tid = leaf_position[ridx];
if (pred_counter[bid] != trees.size()) {
pred_counter_error = 1;
continue;
// update cache entry
for (auto &kv : cache_) {
CacheEntry& e = kv.second;
if (e.predictions.size() == 0) {
PredLoopInternal<GBTree>(
e.data.get(), &(e.predictions),
0, trees.size(), true);
} else {
PredLoopInternal<GBTree>(
e.data.get(), &(e.predictions),
old_ntree, trees.size(), false);
}
if (tid < 0) {
tid_error = 1;
continue;
}
pred_buffer[bid] += new_tree[tid].leaf_value();
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
pred_buffer[bid + i + 1] += new_tree.leafvec(tid)[i];
}
pred_counter[bid] += tparam.num_parallel_tree;
}
CHECK(!pred_counter_error) << "incorrect pred_counter[bid]";
CHECK(!tid_error) << "tid cannot be negative";
}
// make a prediction for a single instance
inline void Pred(const RowBatch::Inst &inst,
int64_t buffer_index,
int bst_group,
unsigned root_index,
RegTree::FVec *p_feats,
float *out_pred,
size_t stride,
unsigned ntree_limit) {
size_t itop = 0;
float psum = 0.0f;
// sum of leaf vector
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
const int64_t bid = this->BufferOffset(buffer_index, bst_group);
// number of valid trees
unsigned treeleft = ntree_limit == 0 ? std::numeric_limits<unsigned>::max() : ntree_limit;
// load buffered results if any
if (bid >= 0 && ntree_limit == 0) {
itop = pred_counter[bid];
psum = pred_buffer[bid];
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
vec_psum[i] = pred_buffer[bid + i + 1];
inline float PredValue(const RowBatch::Inst &inst,
int bst_group,
unsigned root_index,
RegTree::FVec *p_feats,
unsigned tree_begin,
unsigned tree_end) {
float psum = 0.0f;
p_feats->Fill(inst);
for (size_t i = tree_begin; i < tree_end; ++i) {
if (tree_info[i] == bst_group) {
int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
psum += (*trees[i])[tid].leaf_value();
}
}
if (itop != trees.size()) {
p_feats->Fill(inst);
for (size_t i = itop; i < trees.size(); ++i) {
if (tree_info[i] == bst_group) {
int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
psum += (*trees[i])[tid].leaf_value();
for (int j = 0; j < mparam.size_leaf_vector; ++j) {
vec_psum[j] += trees[i]->leafvec(tid)[j];
}
if (--treeleft == 0) break;
}
}
p_feats->Drop(inst);
}
// updated the buffered results
if (bid >= 0 && ntree_limit == 0) {
pred_counter[bid] = static_cast<unsigned>(trees.size());
pred_buffer[bid] = psum;
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
pred_buffer[bid + i + 1] = vec_psum[i];
}
}
out_pred[0] = psum;
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
out_pred[stride * (i + 1)] = vec_psum[i];
}
p_feats->Drop(inst);
return psum;
}
// predict independent leaf index
inline void PredPath(DMatrix *p_fmat,
@ -446,6 +444,7 @@ class GBTree : public GradientBooster {
unsigned ntree_limit) {
const MetaInfo& info = p_fmat->info();
// number of valid trees
ntree_limit *= mparam.num_output_group;
if (ntree_limit == 0 || ntree_limit > trees.size()) {
ntree_limit = static_cast<unsigned>(trees.size());
}
@ -482,22 +481,9 @@ class GBTree : public GradientBooster {
}
}
}
/*! \return size of prediction buffer actually needed */
inline size_t PredBufferSize() const {
return mparam.num_output_group * num_pbuffer * (mparam.size_leaf_vector + 1);
}
/*!
* \brief get the buffer offset given a buffer index and group id
* \return calculated buffer offset
*/
inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
if (buffer_index < 0) return -1;
size_t bidx = static_cast<size_t>(buffer_index);
CHECK_LT(bidx, num_pbuffer);
return (bidx + num_pbuffer * bst_group) * (mparam.size_leaf_vector + 1);
}
// --- data structure ---
// base margin
float base_margin_;
// training parameter
GBTreeTrainParam tparam;
// model parameter
@ -506,13 +492,8 @@ class GBTree : public GradientBooster {
std::vector<std::unique_ptr<RegTree> > trees;
/*! \brief some information indicator of the tree, reserved */
std::vector<int> tree_info;
/*! \brief predict buffer size */
size_t num_pbuffer;
/*! \brief prediction buffer */
std::vector<float> pred_buffer;
/*! \brief prediction buffer counter, remember the prediction */
std::vector<unsigned> pred_counter;
// ----training fields----
std::unordered_map<DMatrix*, CacheEntry> cache_;
// configurations for tree
std::vector<std::pair<std::string, std::string> > cfg;
// temporal storage for per thread
@ -524,7 +505,7 @@ class GBTree : public GradientBooster {
// dart
class Dart : public GBTree {
public:
Dart() {}
explicit Dart(float base_margin) : GBTree(base_margin) {}
void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
GBTree::Configure(cfg);
@ -550,44 +531,10 @@ class Dart : public GBTree {
// predict the leaf scores with dropout if ntree_limit = 0
void Predict(DMatrix* p_fmat,
int64_t buffer_offset,
std::vector<float>* out_preds,
unsigned ntree_limit) override {
DropTrees(ntree_limit);
const MetaInfo& info = p_fmat->info();
int nthread;
#pragma omp parallel
{
nthread = omp_get_num_threads();
}
InitThreadTemp(nthread);
std::vector<float> &preds = *out_preds;
const size_t stride = p_fmat->info().num_row * mparam.num_output_group;
preds.resize(stride * (mparam.size_leaf_vector+1));
// start collecting the prediction
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
// parallel over local batch
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
const int tid = omp_get_thread_num();
RegTree::FVec &feats = thread_temp[tid];
int64_t ridx = static_cast<int64_t>(batch.base_rowid + i);
CHECK_LT(static_cast<size_t>(ridx), info.num_row);
// loop over output groups
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
this->Pred(batch[i],
buffer_offset < 0 ? -1 : buffer_offset + ridx,
gid, info.GetRoot(ridx), &feats,
&preds[ridx * mparam.num_output_group + gid], stride,
ntree_limit);
}
}
}
PredLoopInternal<Dart>(p_fmat, out_preds, 0, ntree_limit, true);
}
void Predict(const SparseBatch::Inst& inst,
@ -599,20 +546,24 @@ class Dart : public GBTree {
thread_temp.resize(1, RegTree::FVec());
thread_temp[0].Init(mparam.num_feature);
}
out_preds->resize(mparam.num_output_group * (mparam.size_leaf_vector+1));
out_preds->resize(mparam.num_output_group);
ntree_limit *= mparam.num_output_group;
if (ntree_limit == 0 || ntree_limit > trees.size()) {
ntree_limit = static_cast<unsigned>(trees.size());
}
// loop over output groups
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
this->Pred(inst, -1, gid, root_index, &thread_temp[0],
&(*out_preds)[gid], mparam.num_output_group,
ntree_limit);
(*out_preds)[gid]
= PredValue(inst, gid, root_index,
&thread_temp[0], 0, ntree_limit) + base_margin_;
}
}
protected:
friend class GBTree;
// commit new trees all at once
virtual void
CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
int bst_group) {
void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
int bst_group) override {
for (size_t i = 0; i < new_trees.size(); ++i) {
trees.push_back(std::move(new_trees[i]));
tree_info.push_back(bst_group);
@ -625,44 +576,25 @@ class Dart : public GBTree {
}
}
// predict the leaf scores without dropped trees
inline void Pred(const RowBatch::Inst &inst,
int64_t buffer_index,
int bst_group,
unsigned root_index,
RegTree::FVec *p_feats,
float *out_pred,
size_t stride,
unsigned ntree_limit) {
float psum = 0.0f;
// sum of leaf vector
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
const int64_t bid = this->BufferOffset(buffer_index, bst_group);
inline float PredValue(const RowBatch::Inst &inst,
int bst_group,
unsigned root_index,
RegTree::FVec *p_feats,
unsigned tree_begin,
unsigned tree_end) {
float psum = 0.0f;
p_feats->Fill(inst);
for (size_t i = 0; i < trees.size(); ++i) {
for (size_t i = tree_begin; i < tree_end; ++i) {
if (tree_info[i] == bst_group) {
bool drop = (std::find(idx_drop.begin(), idx_drop.end(), i) != idx_drop.end());
bool drop = (std::binary_search(idx_drop.begin(), idx_drop.end(), i));
if (!drop) {
int tid = trees[i]->GetLeafIndex(*p_feats, root_index);
psum += weight_drop[i] * (*trees[i])[tid].leaf_value();
for (int j = 0; j < mparam.size_leaf_vector; ++j) {
vec_psum[j] += weight_drop[i] * trees[i]->leafvec(tid)[j];
}
}
}
}
p_feats->Drop(inst);
// updated the buffered results
if (bid >= 0 && ntree_limit == 0) {
pred_counter[bid] = static_cast<unsigned>(trees.size());
pred_buffer[bid] = psum;
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
pred_buffer[bid + i + 1] = vec_psum[i];
}
}
out_pred[0] = psum;
for (int i = 0; i < mparam.size_leaf_vector; ++i) {
out_pred[stride * (i + 1)] = vec_psum[i];
}
return psum;
}
// select dropped trees
@ -744,13 +676,16 @@ DMLC_REGISTER_PARAMETER(DartTrainParam);
XGBOOST_REGISTER_GBM(GBTree, "gbtree")
.describe("Tree booster, gradient boosted trees.")
.set_body([]() {
return new GBTree();
.set_body([](const std::vector<std::shared_ptr<DMatrix> >& cached_mats, float base_margin) {
GBTree* p = new GBTree(base_margin);
p->InitCache(cached_mats);
return p;
});
XGBOOST_REGISTER_GBM(Dart, "dart")
.describe("Tree booster, dart.")
.set_body([]() {
return new Dart();
.set_body([](const std::vector<std::shared_ptr<DMatrix> >& cached_mats, float base_margin) {
GBTree* p = new Dart(base_margin);
return p;
});
} // namespace gbm
} // namespace xgboost

View File

@ -118,20 +118,8 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
*/
class LearnerImpl : public Learner {
public:
explicit LearnerImpl(const std::vector<DMatrix*>& cache_mats)
noexcept(false) {
// setup the cache setting in constructor.
CHECK_EQ(cache_.size(), 0);
size_t buffer_size = 0;
for (auto it = cache_mats.begin(); it != cache_mats.end(); ++it) {
// avoid duplication.
if (std::find(cache_mats.begin(), it, *it) != it) continue;
DMatrix* pmat = *it;
pmat->cache_learner_ptr_ = this;
cache_.push_back(CacheEntry(pmat, buffer_size, pmat->info().num_row));
buffer_size += pmat->info().num_row;
}
pred_buffer_size_ = buffer_size;
explicit LearnerImpl(const std::vector<std::shared_ptr<DMatrix> >& cache)
: cache_(cache) {
// boosted tree
name_obj_ = "reg:linear";
name_gbm_ = "gbtree";
@ -257,7 +245,7 @@ class LearnerImpl : public Learner {
<< "BoostLearner: wrong model format";
// duplicated code with LazyInitModel
obj_.reset(ObjFunction::Create(name_obj_));
gbm_.reset(GradientBooster::Create(name_gbm_));
gbm_.reset(GradientBooster::Create(name_gbm_, cache_, mparam.base_score));
gbm_->Load(fi);
if (mparam.contain_extra_attrs != 0) {
std::vector<std::pair<std::string, std::string> > attr;
@ -265,8 +253,6 @@ class LearnerImpl : public Learner {
attributes_ = std::map<std::string, std::string>(
attr.begin(), attr.end());
}
this->base_score_ = mparam.base_score;
gbm_->ResetPredBuffer(pred_buffer_size_);
cfg_["num_class"] = common::ToString(mparam.num_class);
cfg_["num_feature"] = common::ToString(mparam.num_feature);
obj_->Configure(cfg_.begin(), cfg_.end());
@ -294,7 +280,7 @@ class LearnerImpl : public Learner {
this->LazyInitDMatrix(train);
this->PredictRaw(train, &preds_);
obj_->GetGradient(preds_, train->info(), iter, &gpair_);
gbm_->DoBoost(train, this->FindBufferOffset(train), &gpair_);
gbm_->DoBoost(train, &gpair_, obj_.get());
}
void BoostOneIter(int iter,
@ -304,7 +290,7 @@ class LearnerImpl : public Learner {
common::GlobalRandom().seed(tparam.seed * kRandSeedMagic + iter);
}
this->LazyInitDMatrix(train);
gbm_->DoBoost(train, this->FindBufferOffset(train), in_gpair);
gbm_->DoBoost(train, in_gpair);
}
std::string EvalOneIter(int iter,
@ -435,28 +421,24 @@ class LearnerImpl : public Learner {
// estimate feature bound
unsigned num_feature = 0;
for (size_t i = 0; i < cache_.size(); ++i) {
CHECK(cache_[i] != nullptr);
num_feature = std::max(num_feature,
static_cast<unsigned>(cache_[i].mat_->info().num_col));
static_cast<unsigned>(cache_[i]->info().num_col));
}
// run allreduce on num_feature to find the maximum value
rabit::Allreduce<rabit::op::Max>(&num_feature, 1);
if (num_feature > mparam.num_feature) {
mparam.num_feature = num_feature;
}
// setup
cfg_["num_feature"] = common::ToString(mparam.num_feature);
CHECK(obj_.get() == nullptr && gbm_.get() == nullptr);
obj_.reset(ObjFunction::Create(name_obj_));
gbm_.reset(GradientBooster::Create(name_gbm_));
gbm_->Configure(cfg_.begin(), cfg_.end());
obj_->Configure(cfg_.begin(), cfg_.end());
// reset the base score
mparam.base_score = obj_->ProbToMargin(mparam.base_score);
this->base_score_ = mparam.base_score;
gbm_->ResetPredBuffer(pred_buffer_size_);
gbm_.reset(GradientBooster::Create(name_gbm_, cache_, mparam.base_score));
gbm_->Configure(cfg_.begin(), cfg_.end());
}
/*!
* \brief get un-transformed prediction
@ -471,29 +453,9 @@ class LearnerImpl : public Learner {
CHECK(gbm_.get() != nullptr)
<< "Predict must happen after Load or InitModel";
gbm_->Predict(data,
this->FindBufferOffset(data),
out_preds,
ntree_limit);
// add base margin
std::vector<float>& preds = *out_preds;
const bst_omp_uint ndata = static_cast<bst_omp_uint>(preds.size());
const std::vector<bst_float>& base_margin = data->info().base_margin;
if (base_margin.size() != 0) {
CHECK_EQ(preds.size(), base_margin.size())
<< "base_margin.size does not match with prediction size";
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
preds[j] += base_margin[j];
}
} else {
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
preds[j] += this->base_score_;
}
}
}
// cached size of predict buffer
size_t pred_buffer_size_;
// model parameter
LearnerModelParam mparam;
// training parameter
@ -514,31 +476,11 @@ class LearnerImpl : public Learner {
private:
/*! \brief random number transformation seed. */
static const int kRandSeedMagic = 127;
// cache entry object that helps handle feature caching
struct CacheEntry {
const DMatrix* mat_;
size_t buffer_offset_;
size_t num_row_;
CacheEntry(const DMatrix* mat, size_t buffer_offset, size_t num_row)
:mat_(mat), buffer_offset_(buffer_offset), num_row_(num_row) {}
};
// find internal buffer offset for certain matrix, if not exist, return -1
inline int64_t FindBufferOffset(const DMatrix* mat) const {
for (size_t i = 0; i < cache_.size(); ++i) {
if (cache_[i].mat_ == mat && mat->cache_learner_ptr_ == this) {
if (cache_[i].num_row_ == mat->info().num_row) {
return static_cast<int64_t>(cache_[i].buffer_offset_);
}
}
}
return -1;
}
/*! \brief the entries indicates that we have internal prediction cache */
std::vector<CacheEntry> cache_;
// internal cached dmatrix
std::vector<std::shared_ptr<DMatrix> > cache_;
};
Learner* Learner::Create(const std::vector<DMatrix*>& cache_data) {
Learner* Learner::Create(const std::vector<std::shared_ptr<DMatrix> >& cache_data) {
return new LearnerImpl(cache_data);
}
} // namespace xgboost