Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. (#3116)
* Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. - replacement was performed in the learner, boosters, predictors, updaters, and objective functions - only interfaces used in training were replaced; interfaces like PredictInstance() still use std::vector - refactoring necessary for replacement of interfaces was also performed, such as using HostDeviceVector in prediction cache * HostDeviceVector-based interfaces for custom objective function example plugin.
This commit is contained in:
committed by
Rory Mitchell
parent
11bfa8584d
commit
d5992dd881
@@ -76,8 +76,10 @@ class GBLinear : public GradientBooster {
|
||||
void Save(dmlc::Stream* fo) const override {
|
||||
model.Save(fo);
|
||||
}
|
||||
void DoBoost(DMatrix *p_fmat, std::vector<bst_gpair> *in_gpair,
|
||||
ObjFunction *obj) override {
|
||||
|
||||
void DoBoost(DMatrix *p_fmat,
|
||||
HostDeviceVector<bst_gpair> *in_gpair,
|
||||
ObjFunction* obj) override {
|
||||
monitor.Start("DoBoost");
|
||||
|
||||
if (!p_fmat->HaveColAccess(false)) {
|
||||
@@ -91,14 +93,15 @@ class GBLinear : public GradientBooster {
|
||||
this->LazySumWeights(p_fmat);
|
||||
|
||||
if (!this->CheckConvergence()) {
|
||||
updater->Update(in_gpair, p_fmat, &model, sum_instance_weight);
|
||||
updater->Update(&in_gpair->data_h(), p_fmat, &model, sum_instance_weight);
|
||||
}
|
||||
this->UpdatePredictionCache();
|
||||
|
||||
monitor.Stop("DoBoost");
|
||||
}
|
||||
|
||||
void PredictBatch(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
|
||||
void PredictBatch(DMatrix *p_fmat,
|
||||
HostDeviceVector<bst_float> *out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
monitor.Start("PredictBatch");
|
||||
CHECK_EQ(ntree_limit, 0U)
|
||||
@@ -109,9 +112,9 @@ class GBLinear : public GradientBooster {
|
||||
if (it != cache_.end() && it->second.predictions.size() != 0) {
|
||||
std::vector<bst_float> &y = it->second.predictions;
|
||||
out_preds->resize(y.size());
|
||||
std::copy(y.begin(), y.end(), out_preds->begin());
|
||||
std::copy(y.begin(), y.end(), out_preds->data_h().begin());
|
||||
} else {
|
||||
this->PredictBatchInternal(p_fmat, out_preds);
|
||||
this->PredictBatchInternal(p_fmat, &out_preds->data_h());
|
||||
}
|
||||
monitor.Stop("PredictBatch");
|
||||
}
|
||||
|
||||
@@ -22,18 +22,6 @@ GradientBooster* GradientBooster::Create(
|
||||
return (e->body)(cache_mats, base_margin);
|
||||
}
|
||||
|
||||
void GradientBooster::DoBoost(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_gpair>* in_gpair,
|
||||
ObjFunction* obj) {
|
||||
DoBoost(p_fmat, &in_gpair->data_h(), obj);
|
||||
}
|
||||
|
||||
void GradientBooster::PredictBatch(DMatrix* dmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit) {
|
||||
PredictBatch(dmat, &out_preds->data_h(), ntree_limit);
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@@ -180,22 +180,39 @@ class GBTree : public GradientBooster {
|
||||
tparam.updater_seq.find("distcol") != std::string::npos;
|
||||
}
|
||||
|
||||
void DoBoost(DMatrix* p_fmat,
|
||||
std::vector<bst_gpair>* in_gpair,
|
||||
ObjFunction* obj) override {
|
||||
DoBoostHelper(p_fmat, in_gpair, obj);
|
||||
}
|
||||
|
||||
void DoBoost(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_gpair>* in_gpair,
|
||||
ObjFunction* obj) override {
|
||||
DoBoostHelper(p_fmat, in_gpair, obj);
|
||||
}
|
||||
|
||||
void PredictBatch(DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
predictor->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
|
||||
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
|
||||
const int ngroup = model_.param.num_output_group;
|
||||
monitor.Start("BoostNewTrees");
|
||||
if (ngroup == 1) {
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(in_gpair, p_fmat, 0, &ret);
|
||||
new_trees.push_back(std::move(ret));
|
||||
} else {
|
||||
CHECK_EQ(in_gpair->size() % ngroup, 0U)
|
||||
<< "must have exactly ngroup*nrow gpairs";
|
||||
// TODO(canonizer): perform this on GPU if HostDeviceVector has device set.
|
||||
HostDeviceVector<bst_gpair> tmp(in_gpair->size() / ngroup,
|
||||
bst_gpair(), in_gpair->device());
|
||||
std::vector<bst_gpair>& gpair_h = in_gpair->data_h();
|
||||
bst_omp_uint nsize = static_cast<bst_omp_uint>(tmp.size());
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
std::vector<bst_gpair>& tmp_h = tmp.data_h();
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
tmp_h[i] = gpair_h[i * ngroup + gid];
|
||||
}
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(&tmp, p_fmat, gid, &ret);
|
||||
new_trees.push_back(std::move(ret));
|
||||
}
|
||||
}
|
||||
monitor.Stop("BoostNewTrees");
|
||||
monitor.Start("CommitModel");
|
||||
this->CommitModel(std::move(new_trees));
|
||||
monitor.Stop("CommitModel");
|
||||
}
|
||||
|
||||
void PredictBatch(DMatrix* p_fmat,
|
||||
@@ -251,48 +268,11 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
}
|
||||
|
||||
// TVec is either std::vector<bst_gpair> or HostDeviceVector<bst_gpair>
|
||||
template <typename TVec>
|
||||
void DoBoostHelper(DMatrix* p_fmat,
|
||||
TVec* in_gpair,
|
||||
ObjFunction* obj) {
|
||||
std::vector<std::vector<std::unique_ptr<RegTree> > > new_trees;
|
||||
const int ngroup = model_.param.num_output_group;
|
||||
monitor.Start("BoostNewTrees");
|
||||
if (ngroup == 1) {
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(in_gpair, p_fmat, 0, &ret);
|
||||
new_trees.push_back(std::move(ret));
|
||||
} else {
|
||||
CHECK_EQ(in_gpair->size() % ngroup, 0U)
|
||||
<< "must have exactly ngroup*nrow gpairs";
|
||||
std::vector<bst_gpair> tmp(in_gpair->size() / ngroup);
|
||||
auto& gpair_h = HostDeviceVector<bst_gpair>::data_h(in_gpair);
|
||||
for (int gid = 0; gid < ngroup; ++gid) {
|
||||
bst_omp_uint nsize = static_cast<bst_omp_uint>(tmp.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
tmp[i] = gpair_h[i * ngroup + gid];
|
||||
}
|
||||
std::vector<std::unique_ptr<RegTree> > ret;
|
||||
BoostNewTrees(&tmp, p_fmat, gid, &ret);
|
||||
new_trees.push_back(std::move(ret));
|
||||
}
|
||||
}
|
||||
monitor.Stop("BoostNewTrees");
|
||||
monitor.Start("CommitModel");
|
||||
this->CommitModel(std::move(new_trees));
|
||||
monitor.Stop("CommitModel");
|
||||
}
|
||||
|
||||
// do group specific group
|
||||
// TVec is either const std::vector<bst_gpair> or HostDeviceVector<bst_gpair>
|
||||
template <typename TVec>
|
||||
inline void
|
||||
BoostNewTrees(TVec* gpair,
|
||||
DMatrix *p_fmat,
|
||||
int bst_group,
|
||||
std::vector<std::unique_ptr<RegTree> >* ret) {
|
||||
inline void BoostNewTrees(HostDeviceVector<bst_gpair>* gpair,
|
||||
DMatrix *p_fmat,
|
||||
int bst_group,
|
||||
std::vector<std::unique_ptr<RegTree> >* ret) {
|
||||
this->InitUpdater();
|
||||
std::vector<RegTree*> new_trees;
|
||||
ret->clear();
|
||||
@@ -315,23 +295,8 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
}
|
||||
// update the trees
|
||||
for (auto& up : updaters) {
|
||||
UpdateHelper(up.get(), gpair, p_fmat, new_trees);
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateHelper(TreeUpdater* updater,
|
||||
std::vector<bst_gpair>* gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<RegTree*>& new_trees) {
|
||||
updater->Update(*gpair, p_fmat, new_trees);
|
||||
}
|
||||
|
||||
void UpdateHelper(TreeUpdater* updater,
|
||||
HostDeviceVector<bst_gpair>* gpair,
|
||||
DMatrix *p_fmat,
|
||||
const std::vector<RegTree*>& new_trees) {
|
||||
updater->Update(gpair, p_fmat, new_trees);
|
||||
for (auto& up : updaters)
|
||||
up->Update(gpair, p_fmat, new_trees);
|
||||
}
|
||||
|
||||
// commit new trees all at once
|
||||
@@ -389,10 +354,10 @@ class Dart : public GBTree {
|
||||
|
||||
// predict the leaf scores with dropout if ntree_limit = 0
|
||||
void PredictBatch(DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
DropTrees(ntree_limit);
|
||||
PredLoopInternal<Dart>(p_fmat, out_preds, 0, ntree_limit, true);
|
||||
PredLoopInternal<Dart>(p_fmat, &out_preds->data_h(), 0, ntree_limit, true);
|
||||
}
|
||||
|
||||
void PredictInstance(const SparseBatch::Inst& inst,
|
||||
|
||||
Reference in New Issue
Block a user