Pass shared pointer instead of raw pointer to Learner. (#5302)

Extracted from https://github.com/dmlc/xgboost/pull/5220 .
This commit is contained in:
Jiaming Yuan
2020-02-11 14:16:38 +08:00
committed by GitHub
parent 2e0067e790
commit 29eeea709a
12 changed files with 97 additions and 73 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2019 by Contributors
* Copyright 2014-2020 by Contributors
* \file learner.cc
* \brief Implementation of learning algorithm.
* \author Tianqi Chen
@@ -691,7 +691,7 @@ class LearnerImpl : public Learner {
return gbm_->DumpModel(fmap, with_stats, format);
}
void UpdateOneIter(int iter, DMatrix* train) override {
void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
monitor_.Start("UpdateOneIter");
TrainingObserver::Instance().Update(iter);
this->Configure();
@@ -699,23 +699,23 @@ class LearnerImpl : public Learner {
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
}
this->CheckDataSplitMode();
this->ValidateDMatrix(train);
this->ValidateDMatrix(train.get());
monitor_.Start("PredictRaw");
this->PredictRaw(train, &preds_[train], true);
this->PredictRaw(train.get(), &preds_[train.get()], true);
monitor_.Stop("PredictRaw");
TrainingObserver::Instance().Observe(preds_[train], "Predictions");
TrainingObserver::Instance().Observe(preds_[train.get()], "Predictions");
monitor_.Start("GetGradient");
obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_);
obj_->GetGradient(preds_[train.get()], train->Info(), iter, &gpair_);
monitor_.Stop("GetGradient");
TrainingObserver::Instance().Observe(gpair_, "Gradients");
gbm_->DoBoost(train, &gpair_, obj_.get());
gbm_->DoBoost(train.get(), &gpair_, obj_.get());
monitor_.Stop("UpdateOneIter");
}
void BoostOneIter(int iter, DMatrix* train,
void BoostOneIter(int iter, std::shared_ptr<DMatrix> train,
HostDeviceVector<GradientPair>* in_gpair) override {
monitor_.Start("BoostOneIter");
this->Configure();
@@ -723,13 +723,13 @@ class LearnerImpl : public Learner {
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
}
this->CheckDataSplitMode();
this->ValidateDMatrix(train);
this->ValidateDMatrix(train.get());
gbm_->DoBoost(train, in_gpair);
gbm_->DoBoost(train.get(), in_gpair);
monitor_.Stop("BoostOneIter");
}
std::string EvalOneIter(int iter, const std::vector<DMatrix*>& data_sets,
std::string EvalOneIter(int iter, const std::vector<std::shared_ptr<DMatrix>>& data_sets,
const std::vector<std::string>& data_names) override {
monitor_.Start("EvalOneIter");
this->Configure();
@@ -741,9 +741,9 @@ class LearnerImpl : public Learner {
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
}
for (size_t i = 0; i < data_sets.size(); ++i) {
DMatrix * dmat = data_sets[i];
DMatrix * dmat = data_sets[i].get();
this->ValidateDMatrix(dmat);
this->PredictRaw(data_sets[i], &preds_[dmat], false);
this->PredictRaw(dmat, &preds_[dmat], false);
obj_->EvalTransform(&preds_[dmat]);
for (auto& ev : metrics_) {
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
@@ -805,7 +805,7 @@ class LearnerImpl : public Learner {
return generic_parameters_;
}
void Predict(DMatrix* data, bool output_margin,
void Predict(std::shared_ptr<DMatrix> data, bool output_margin,
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
bool training,
bool pred_leaf, bool pred_contribs, bool approx_contribs,
@@ -816,14 +816,14 @@ class LearnerImpl : public Learner {
this->Configure();
CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
if (pred_contribs) {
gbm_->PredictContribution(data, &out_preds->HostVector(), ntree_limit, approx_contribs);
gbm_->PredictContribution(data.get(), &out_preds->HostVector(), ntree_limit, approx_contribs);
} else if (pred_interactions) {
gbm_->PredictInteractionContributions(data, &out_preds->HostVector(), ntree_limit,
gbm_->PredictInteractionContributions(data.get(), &out_preds->HostVector(), ntree_limit,
approx_contribs);
} else if (pred_leaf) {
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
} else {
this->PredictRaw(data, out_preds, training, ntree_limit);
this->PredictRaw(data.get(), out_preds, training, ntree_limit);
if (!output_margin) {
obj_->PredTransform(out_preds);
}