Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. (#3116)
* Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. - replacement was performed in the learner, boosters, predictors, updaters, and objective functions - only interfaces used in training were replaced; interfaces like PredictInstance() still use std::vector - refactoring necessary for replacement of interfaces was also performed, such as using HostDeviceVector in prediction cache * HostDeviceVector-based interfaces for custom objective function example plugin.
This commit is contained in:
committed by
Rory Mitchell
parent
11bfa8584d
commit
d5992dd881
@@ -104,14 +104,43 @@ class CPUPredictor : public Predictor {
|
||||
tree_begin, ntree_limit);
|
||||
}
|
||||
|
||||
public:
|
||||
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) override {
|
||||
PredictBatch(dmat, &out_preds->data_h(), model, tree_begin, ntree_limit);
|
||||
bool PredictFromCache(DMatrix* dmat,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model,
|
||||
unsigned ntree_limit) {
|
||||
if (ntree_limit == 0 ||
|
||||
ntree_limit * model.param.num_output_group >= model.trees.size()) {
|
||||
auto it = cache_.find(dmat);
|
||||
if (it != cache_.end()) {
|
||||
HostDeviceVector<bst_float>& y = it->second.predictions;
|
||||
if (y.size() != 0) {
|
||||
out_preds->resize(y.size());
|
||||
std::copy(y.data_h().begin(), y.data_h().end(),
|
||||
out_preds->data_h().begin());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void PredictBatch(DMatrix* dmat, std::vector<bst_float>* out_preds,
|
||||
void InitOutPredictions(const MetaInfo& info,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model) const {
|
||||
size_t n = model.param.num_output_group * info.num_row;
|
||||
const std::vector<bst_float>& base_margin = info.base_margin;
|
||||
out_preds->resize(n);
|
||||
std::vector<bst_float>& out_preds_h = out_preds->data_h();
|
||||
if (base_margin.size() != 0) {
|
||||
CHECK_EQ(out_preds->size(), n);
|
||||
std::copy(base_margin.begin(), base_margin.end(), out_preds_h.begin());
|
||||
} else {
|
||||
std::fill(out_preds_h.begin(), out_preds_h.end(), model.base_margin);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model, int tree_begin,
|
||||
unsigned ntree_limit = 0) override {
|
||||
if (this->PredictFromCache(dmat, out_preds, model, ntree_limit)) {
|
||||
@@ -125,12 +154,14 @@ class CPUPredictor : public Predictor {
|
||||
ntree_limit = static_cast<unsigned>(model.trees.size());
|
||||
}
|
||||
|
||||
this->PredLoopInternal(dmat, out_preds, model, tree_begin, ntree_limit);
|
||||
this->PredLoopInternal(dmat, &out_preds->data_h(), model,
|
||||
tree_begin, ntree_limit);
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(const gbm::GBTreeModel& model,
|
||||
std::vector<std::unique_ptr<TreeUpdater>>* updaters,
|
||||
int num_new_trees) override {
|
||||
void UpdatePredictionCache(
|
||||
const gbm::GBTreeModel& model,
|
||||
std::vector<std::unique_ptr<TreeUpdater>>* updaters,
|
||||
int num_new_trees) override {
|
||||
int old_ntree = model.trees.size() - num_new_trees;
|
||||
// update cache entry
|
||||
for (auto& kv : cache_) {
|
||||
@@ -138,7 +169,7 @@ class CPUPredictor : public Predictor {
|
||||
|
||||
if (e.predictions.size() == 0) {
|
||||
InitOutPredictions(e.data->info(), &(e.predictions), model);
|
||||
PredLoopInternal(e.data.get(), &(e.predictions), model, 0,
|
||||
PredLoopInternal(e.data.get(), &(e.predictions.data_h()), model, 0,
|
||||
model.trees.size());
|
||||
} else if (model.param.num_output_group == 1 && updaters->size() > 0 &&
|
||||
num_new_trees == 1 &&
|
||||
@@ -146,7 +177,7 @@ class CPUPredictor : public Predictor {
|
||||
&(e.predictions))) {
|
||||
{} // do nothing
|
||||
} else {
|
||||
PredLoopInternal(e.data.get(), &(e.predictions), model, old_ntree,
|
||||
PredLoopInternal(e.data.get(), &(e.predictions.data_h()), model, old_ntree,
|
||||
model.trees.size());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user