Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. (#3116)
* Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. - replacement was performed in the learner, boosters, predictors, updaters, and objective functions - only interfaces used in training were replaced; interfaces like PredictInstance() still use std::vector - refactoring necessary for replacement of interfaces was also performed, such as using HostDeviceVector in prediction cache * HostDeviceVector-based interfaces for custom objective function example plugin.
This commit is contained in:
committed by
Rory Mitchell
parent
11bfa8584d
commit
d5992dd881
@@ -362,17 +362,17 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
this->LazyInitDMatrix(train);
|
||||
monitor.Start("PredictRaw");
|
||||
this->PredictRaw(train, &preds2_);
|
||||
this->PredictRaw(train, &preds_);
|
||||
monitor.Stop("PredictRaw");
|
||||
monitor.Start("GetGradient");
|
||||
obj_->GetGradient(&preds2_, train->info(), iter, &gpair_);
|
||||
obj_->GetGradient(&preds_, train->info(), iter, &gpair_);
|
||||
monitor.Stop("GetGradient");
|
||||
gbm_->DoBoost(train, &gpair_, obj_.get());
|
||||
monitor.Stop("UpdateOneIter");
|
||||
}
|
||||
|
||||
void BoostOneIter(int iter, DMatrix* train,
|
||||
std::vector<bst_gpair>* in_gpair) override {
|
||||
HostDeviceVector<bst_gpair>* in_gpair) override {
|
||||
monitor.Start("BoostOneIter");
|
||||
if (tparam.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(tparam.seed * kRandSeedMagic + iter);
|
||||
@@ -395,7 +395,7 @@ class LearnerImpl : public Learner {
|
||||
obj_->EvalTransform(&preds_);
|
||||
for (auto& ev : metrics_) {
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||
<< ev->Eval(preds_, data_sets[i]->info(), tparam.dsplit == 2);
|
||||
<< ev->Eval(preds_.data_h(), data_sets[i]->info(), tparam.dsplit == 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -438,19 +438,20 @@ class LearnerImpl : public Learner {
|
||||
this->PredictRaw(data, &preds_);
|
||||
obj_->EvalTransform(&preds_);
|
||||
return std::make_pair(metric,
|
||||
ev->Eval(preds_, data->info(), tparam.dsplit == 2));
|
||||
ev->Eval(preds_.data_h(), data->info(), tparam.dsplit == 2));
|
||||
}
|
||||
|
||||
void Predict(DMatrix* data, bool output_margin,
|
||||
std::vector<bst_float>* out_preds, unsigned ntree_limit,
|
||||
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
|
||||
bool pred_leaf, bool pred_contribs, bool approx_contribs,
|
||||
bool pred_interactions) const override {
|
||||
if (pred_contribs) {
|
||||
gbm_->PredictContribution(data, out_preds, ntree_limit, approx_contribs);
|
||||
gbm_->PredictContribution(data, &out_preds->data_h(), ntree_limit, approx_contribs);
|
||||
} else if (pred_interactions) {
|
||||
gbm_->PredictInteractionContributions(data, out_preds, ntree_limit, approx_contribs);
|
||||
gbm_->PredictInteractionContributions(data, &out_preds->data_h(), ntree_limit,
|
||||
approx_contribs);
|
||||
} else if (pred_leaf) {
|
||||
gbm_->PredictLeaf(data, out_preds, ntree_limit);
|
||||
gbm_->PredictLeaf(data, &out_preds->data_h(), ntree_limit);
|
||||
} else {
|
||||
this->PredictRaw(data, out_preds, ntree_limit);
|
||||
if (!output_margin) {
|
||||
@@ -546,12 +547,6 @@ class LearnerImpl : public Learner {
|
||||
* \param ntree_limit limit number of trees used for boosted tree
|
||||
* predictor, when it equals 0, this means we are using all the trees
|
||||
*/
|
||||
inline void PredictRaw(DMatrix* data, std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) const {
|
||||
CHECK(gbm_.get() != nullptr)
|
||||
<< "Predict must happen after Load or InitModel";
|
||||
gbm_->PredictBatch(data, out_preds, ntree_limit);
|
||||
}
|
||||
inline void PredictRaw(DMatrix* data, HostDeviceVector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) const {
|
||||
CHECK(gbm_.get() != nullptr)
|
||||
@@ -572,8 +567,7 @@ class LearnerImpl : public Learner {
|
||||
// name of objective function
|
||||
std::string name_obj_;
|
||||
// temporal storages for prediction
|
||||
std::vector<bst_float> preds_;
|
||||
HostDeviceVector<bst_float> preds2_;
|
||||
HostDeviceVector<bst_float> preds_;
|
||||
// gradient pairs
|
||||
HostDeviceVector<bst_gpair> gpair_;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user