Replaced std::vector-based interfaces with HostDeviceVector-based interfaces. (#3116)

* Replaced std::vector-based interfaces with HostDeviceVector-based interfaces.

- replacement was performed in the learner, boosters, predictors,
  updaters, and objective functions
- only interfaces used in training were replaced;
  interfaces like PredictInstance() still use std::vector
- refactoring necessary for replacement of interfaces was also performed,
  such as using HostDeviceVector in prediction cache

* HostDeviceVector-based interfaces for custom objective function example plugin.
This commit is contained in:
Andrew V. Adinetz
2018-02-28 01:00:04 +01:00
committed by Rory Mitchell
parent 11bfa8584d
commit d5992dd881
38 changed files with 371 additions and 519 deletions

View File

@@ -22,17 +22,6 @@ TreeUpdater* TreeUpdater::Create(const std::string& name) {
return (e->body)();
}
void TreeUpdater::Update(HostDeviceVector<bst_gpair>* gpair,
DMatrix* data,
const std::vector<RegTree*>& trees) {
Update(gpair->data_h(), data, trees);
}
bool TreeUpdater::UpdatePredictionCache(const DMatrix* data,
HostDeviceVector<bst_float>* out_preds) {
return UpdatePredictionCache(data, &out_preds->data_h());
}
} // namespace xgboost
namespace xgboost {

View File

@@ -26,7 +26,7 @@ class ColMaker: public TreeUpdater {
param.InitAllowUnknown(args);
}
void Update(const std::vector<bst_gpair> &gpair,
void Update(HostDeviceVector<bst_gpair> *gpair,
DMatrix* dmat,
const std::vector<RegTree*> &trees) override {
TStats::CheckInfo(dmat->info());
@@ -37,7 +37,7 @@ class ColMaker: public TreeUpdater {
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
Builder builder(param);
builder.Update(gpair, dmat, trees[i]);
builder.Update(gpair->data_h(), dmat, trees[i]);
}
param.learning_rate = lr;
}
@@ -806,13 +806,13 @@ class DistColMaker : public ColMaker<TStats, TConstraint> {
param.InitAllowUnknown(args);
pruner->Init(args);
}
void Update(const std::vector<bst_gpair> &gpair,
void Update(HostDeviceVector<bst_gpair> *gpair,
DMatrix* dmat,
const std::vector<RegTree*> &trees) override {
TStats::CheckInfo(dmat->info());
CHECK_EQ(trees.size(), 1U) << "DistColMaker: only support one tree at a time";
// build the tree
builder.Update(gpair, dmat, trees[0]);
builder.Update(gpair->data_h(), dmat, trees[0]);
//// prune the tree, note that pruner will sync the tree
pruner->Update(gpair, dmat, trees);
// update position after the tree is pruned
@@ -967,7 +967,7 @@ class TreeUpdaterSwitch : public TreeUpdater {
inner_->Init(args);
}
void Update(const std::vector<bst_gpair>& gpair,
void Update(HostDeviceVector<bst_gpair>* gpair,
DMatrix* data,
const std::vector<RegTree*>& trees) override {
CHECK(inner_ != nullptr);

View File

@@ -55,7 +55,7 @@ class FastHistMaker: public TreeUpdater {
is_gmat_initialized_ = false;
}
void Update(const std::vector<bst_gpair>& gpair,
void Update(HostDeviceVector<bst_gpair>* gpair,
DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
TStats::CheckInfo(dmat->info());
@@ -82,13 +82,14 @@ class FastHistMaker: public TreeUpdater {
builder_.reset(new Builder(param, fhparam, std::move(pruner_)));
}
for (size_t i = 0; i < trees.size(); ++i) {
builder_->Update(gmat_, gmatb_, column_matrix_, gpair, dmat, trees[i]);
builder_->Update
(gmat_, gmatb_, column_matrix_, gpair, dmat, trees[i]);
}
param.learning_rate = lr;
}
bool UpdatePredictionCache(const DMatrix* data,
std::vector<bst_float>* out_preds) override {
HostDeviceVector<bst_float>* out_preds) override {
if (!builder_ || param.subsample < 1.0f) {
return false;
} else {
@@ -139,7 +140,7 @@ class FastHistMaker: public TreeUpdater {
virtual void Update(const GHistIndexMatrix& gmat,
const GHistIndexBlockMatrix& gmatb,
const ColumnMatrix& column_matrix,
const std::vector<bst_gpair>& gpair,
HostDeviceVector<bst_gpair>* gpair,
DMatrix* p_fmat,
RegTree* p_tree) {
double gstart = dmlc::GetTime();
@@ -154,8 +155,10 @@ class FastHistMaker: public TreeUpdater {
double time_evaluate_split = 0;
double time_apply_split = 0;
std::vector<bst_gpair>& gpair_h = gpair->data_h();
tstart = dmlc::GetTime();
this->InitData(gmat, gpair, *p_fmat, *p_tree);
this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
std::vector<bst_uint> feat_set = feat_index;
time_init_data = dmlc::GetTime() - tstart;
@@ -165,11 +168,11 @@ class FastHistMaker: public TreeUpdater {
for (int nid = 0; nid < p_tree->param.num_roots; ++nid) {
tstart = dmlc::GetTime();
hist_.AddHistRow(nid);
BuildHist(gpair, row_set_collection_[nid], gmat, gmatb, feat_set, hist_[nid]);
BuildHist(gpair_h, row_set_collection_[nid], gmat, gmatb, feat_set, hist_[nid]);
time_build_hist += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
this->InitNewNode(nid, gmat, gpair, *p_fmat, *p_tree);
this->InitNewNode(nid, gmat, gpair_h, *p_fmat, *p_tree);
time_init_new_node += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
@@ -200,17 +203,17 @@ class FastHistMaker: public TreeUpdater {
hist_.AddHistRow(cleft);
hist_.AddHistRow(cright);
if (row_set_collection_[cleft].size() < row_set_collection_[cright].size()) {
BuildHist(gpair, row_set_collection_[cleft], gmat, gmatb, feat_set, hist_[cleft]);
BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, feat_set, hist_[cleft]);
SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
} else {
BuildHist(gpair, row_set_collection_[cright], gmat, gmatb, feat_set, hist_[cright]);
BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, feat_set, hist_[cright]);
SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
}
time_build_hist += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
this->InitNewNode(cleft, gmat, gpair, *p_fmat, *p_tree);
this->InitNewNode(cright, gmat, gpair, *p_fmat, *p_tree);
this->InitNewNode(cleft, gmat, gpair_h, *p_fmat, *p_tree);
this->InitNewNode(cright, gmat, gpair_h, *p_fmat, *p_tree);
time_init_new_node += dmlc::GetTime() - tstart;
tstart = dmlc::GetTime();
@@ -293,8 +296,8 @@ class FastHistMaker: public TreeUpdater {
}
inline bool UpdatePredictionCache(const DMatrix* data,
std::vector<bst_float>* p_out_preds) {
std::vector<bst_float>& out_preds = *p_out_preds;
HostDeviceVector<bst_float>* p_out_preds) {
std::vector<bst_float>& out_preds = p_out_preds->data_h();
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
// conjunction with Update().

View File

@@ -512,7 +512,7 @@ class GPUMaker : public TreeUpdater {
maxLeaves = 1 << param.max_depth;
}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
void Update(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
GradStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
@@ -530,7 +530,7 @@ class GPUMaker : public TreeUpdater {
param.learning_rate = lr;
}
/// @note: Update should be only after Init!!
void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
void UpdateTree(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
RegTree* hTree) {
if (!allocated) {
setupOneTimeData(dmat);
@@ -687,11 +687,11 @@ class GPUMaker : public TreeUpdater {
assignColIds<<<nCols, 512>>>(colIds.data(), colOffsets.data());
}
void transferGrads(const std::vector<bst_gpair>& gpair) {
void transferGrads(HostDeviceVector<bst_gpair>* gpair) {
// HACK
dh::safe_cuda(cudaMemcpy(gradsInst.data(), &(gpair[0]),
dh::safe_cuda(cudaMemcpy(gradsInst.data(), gpair->ptr_d(param.gpu_id),
sizeof(bst_gpair) * nRows,
cudaMemcpyHostToDevice));
cudaMemcpyDefault));
// evaluate the full-grad reduction for the root node
dh::sumReduction<bst_gpair>(tmp_mem, gradsInst, gradSums, nRows);
}

View File

@@ -506,27 +506,9 @@ class GPUHistMaker : public TreeUpdater {
monitor.Init("updater_gpu_hist", param.debug_verbose);
}
void Update(const std::vector<bst_gpair>& gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
monitor.Start("Update", dList);
// TODO(canonizer): move it into the class if this ever becomes a bottleneck
HostDeviceVector<bst_gpair> gpair_d(gpair.size(), param.gpu_id);
dh::safe_cuda(cudaSetDevice(param.gpu_id));
thrust::copy(gpair.begin(), gpair.end(), gpair_d.tbegin(param.gpu_id));
Update(&gpair_d, dmat, trees);
monitor.Stop("Update", dList);
}
void Update(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) override {
monitor.Start("Update", dList);
UpdateHelper(gpair, dmat, trees);
monitor.Stop("Update", dList);
}
private:
void UpdateHelper(HostDeviceVector<bst_gpair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
GradStats::CheckInfo(dmat->info());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
@@ -541,9 +523,9 @@ class GPUHistMaker : public TreeUpdater {
LOG(FATAL) << "GPU plugin exception: " << e.what() << std::endl;
}
param.learning_rate = lr;
monitor.Stop("Update", dList);
}
public:
void InitDataOnce(DMatrix* dmat) {
info = &dmat->info();
monitor.Start("Quantiles", dList);
@@ -876,16 +858,6 @@ class GPUHistMaker : public TreeUpdater {
omp_set_num_threads(nthread);
}
bool UpdatePredictionCache(const DMatrix* data,
std::vector<bst_float>* p_out_preds) override {
return false;
}
bool UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
return false;
}
struct ExpandEntry {
int nid;
int depth;

View File

@@ -21,7 +21,7 @@ DMLC_REGISTRY_FILE_TAG(updater_histmaker);
template<typename TStats>
class HistMaker: public BaseMaker {
public:
void Update(const std::vector<bst_gpair> &gpair,
void Update(HostDeviceVector<bst_gpair> *gpair,
DMatrix *p_fmat,
const std::vector<RegTree*> &trees) override {
TStats::CheckInfo(p_fmat->info());
@@ -30,7 +30,7 @@ class HistMaker: public BaseMaker {
param.learning_rate = lr / trees.size();
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
this->Update(gpair, p_fmat, trees[i]);
this->Update(gpair->data_h(), p_fmat, trees[i]);
}
param.learning_rate = lr;
}

View File

@@ -29,7 +29,7 @@ class TreePruner: public TreeUpdater {
syncher->Init(args);
}
// update the tree, do pruning
void Update(const std::vector<bst_gpair> &gpair,
void Update(HostDeviceVector<bst_gpair> *gpair,
DMatrix *p_fmat,
const std::vector<RegTree*> &trees) override {
// rescale learning rate according to size of trees

View File

@@ -25,10 +25,11 @@ class TreeRefresher: public TreeUpdater {
param.InitAllowUnknown(args);
}
// update the tree, do pruning
void Update(const std::vector<bst_gpair> &gpair,
void Update(HostDeviceVector<bst_gpair> *gpair,
DMatrix *p_fmat,
const std::vector<RegTree*> &trees) override {
if (trees.size() == 0) return;
std::vector<bst_gpair> &gpair_h = gpair->data_h();
// number of threads
// thread temporal space
std::vector<std::vector<TStats> > stemp;
@@ -71,7 +72,7 @@ class TreeRefresher: public TreeUpdater {
feats.Fill(inst);
int offset = 0;
for (size_t j = 0; j < trees.size(); ++j) {
AddStats(*trees[j], feats, gpair, info, ridx,
AddStats(*trees[j], feats, gpair_h, info, ridx,
dmlc::BeginPtr(stemp[tid]) + offset);
offset += trees[j]->param.num_nodes;
}

View File

@@ -22,7 +22,7 @@ DMLC_REGISTRY_FILE_TAG(updater_skmaker);
class SketchMaker: public BaseMaker {
public:
void Update(const std::vector<bst_gpair> &gpair,
void Update(HostDeviceVector<bst_gpair> *gpair,
DMatrix *p_fmat,
const std::vector<RegTree*> &trees) override {
// rescale learning rate according to size of trees
@@ -30,7 +30,7 @@ class SketchMaker: public BaseMaker {
param.learning_rate = lr / trees.size();
// build tree
for (size_t i = 0; i < trees.size(); ++i) {
this->Update(gpair, p_fmat, trees[i]);
this->Update(gpair->data_h(), p_fmat, trees[i]);
}
param.learning_rate = lr;
}

View File

@@ -23,7 +23,7 @@ class TreeSyncher: public TreeUpdater {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& args) override {}
void Update(const std::vector<bst_gpair> &gpair,
void Update(HostDeviceVector<bst_gpair> *gpair,
DMatrix* dmat,
const std::vector<RegTree*> &trees) override {
if (rabit::GetWorldSize() == 1) return;