Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. (#3116)
* Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. - replacement was performed in the learner, boosters, predictors, updaters, and objective functions - only interfaces used in training were replaced; interfaces like PredictInstance() still use std::vector - refactoring necessary for replacement of interfaces was also performed, such as using HostDeviceVector in prediction cache * HostDeviceVector-based interfaces for custom objective function example plugin.
This commit is contained in:
committed by
Rory Mitchell
parent
11bfa8584d
commit
d5992dd881
@@ -506,27 +506,9 @@ class GPUHistMaker : public TreeUpdater {
|
||||
monitor.Init("updater_gpu_hist", param.debug_verbose);
|
||||
}
|
||||
|
||||
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
monitor.Start("Update", dList);
|
||||
// TODO(canonizer): move it into the class if this ever becomes a bottleneck
|
||||
HostDeviceVector<bst_gpair> gpair_d(gpair.size(), param.gpu_id);
|
||||
dh::safe_cuda(cudaSetDevice(param.gpu_id));
|
||||
thrust::copy(gpair.begin(), gpair.end(), gpair_d.tbegin(param.gpu_id));
|
||||
Update(&gpair_d, dmat, trees);
|
||||
monitor.Stop("Update", dList);
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) override {
|
||||
monitor.Start("Update", dList);
|
||||
UpdateHelper(gpair, dmat, trees);
|
||||
monitor.Stop("Update", dList);
|
||||
}
|
||||
|
||||
private:
|
||||
void UpdateHelper(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
GradStats::CheckInfo(dmat->info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
@@ -541,9 +523,9 @@ class GPUHistMaker : public TreeUpdater {
|
||||
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
monitor.Stop("Update", dList);
|
||||
}
|
||||
|
||||
public:
|
||||
void InitDataOnce(DMatrix* dmat) {
|
||||
info = &dmat->info();
|
||||
monitor.Start("Quantiles", dList);
|
||||
@@ -876,16 +858,6 @@ class GPUHistMaker : public TreeUpdater {
|
||||
omp_set_num_threads(nthread);
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
std::vector<bst_float>* p_out_preds) override {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
|
||||
return false;
|
||||
}
|
||||
|
||||
struct ExpandEntry {
|
||||
int nid;
|
||||
int depth;
|
||||
|
||||
Reference in New Issue
Block a user