Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. (#3116)

* Replaced std::vector-based interfaces with HostDeviceVector-based interfaces.

- replacement was performed in the learner, boosters, predictors,
  updaters, and objective functions
- only interfaces used in training were replaced;
  interfaces like PredictInstance() still use std::vector
- refactoring necessary for replacement of interfaces was also performed,
  such as using HostDeviceVector in prediction cache

* HostDeviceVector-based interfaces for custom objective function example plugin.
This commit is contained in:
Andrew V. Adinetz
2018-02-28 01:00:04 +01:00
committed by Rory Mitchell
parent 11bfa8584d
commit d5992dd881
38 changed files with 371 additions and 519 deletions

View File

@@ -506,27 +506,9 @@ class GPUHistMaker : public TreeUpdater {
monitor.Init("updater_gpu_hist", param.debug_verbose);
}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
monitor.Start("Update", dList);
// TODO(canonizer): move it into the class if this ever becomes a bottleneck
HostDeviceVector<bst_gpair> gpair_d(gpair.size(), param.gpu_id);
dh::safe_cuda(cudaSetDevice(param.gpu_id));
thrust::copy(gpair.begin(), gpair.end(), gpair_d.tbegin(param.gpu_id));
Update(&gpair_d, dmat, trees);
monitor.Stop("Update", dList);
}
void Update(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
monitor.Start("Update", dList);
UpdateHelper(gpair, dmat, trees);
monitor.Stop("Update", dList);
}
private:
void UpdateHelper(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
GradStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
@@ -541,9 +523,9 @@ class GPUHistMaker : public TreeUpdater {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
monitor.Stop("Update", dList);
}
public:
void InitDataOnce(DMatrix* dmat) {
info = &dmat->info();
monitor.Start("Quantiles", dList);
@@ -876,16 +858,6 @@ class GPUHistMaker : public TreeUpdater {
omp_set_num_threads(nthread);
}
bool UpdatePredictionCache(const DMatrix* data,
std::vector<bst_float>* p_out_preds) override {
return false;
}
bool UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
return false;
}
struct ExpandEntry {
int nid;
int depth;