Pass shared pointer instead of raw pointer to Learner. (#5302)
Extracted from https://github.com/dmlc/xgboost/pull/5220 .
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2019 by Contributors
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* \file learner.cc
|
||||
* \brief Implementation of learning algorithm.
|
||||
* \author Tianqi Chen
|
||||
@@ -691,7 +691,7 @@ class LearnerImpl : public Learner {
|
||||
return gbm_->DumpModel(fmap, with_stats, format);
|
||||
}
|
||||
|
||||
void UpdateOneIter(int iter, DMatrix* train) override {
|
||||
void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
|
||||
monitor_.Start("UpdateOneIter");
|
||||
TrainingObserver::Instance().Update(iter);
|
||||
this->Configure();
|
||||
@@ -699,23 +699,23 @@ class LearnerImpl : public Learner {
|
||||
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->CheckDataSplitMode();
|
||||
this->ValidateDMatrix(train);
|
||||
this->ValidateDMatrix(train.get());
|
||||
|
||||
monitor_.Start("PredictRaw");
|
||||
this->PredictRaw(train, &preds_[train], true);
|
||||
this->PredictRaw(train.get(), &preds_[train.get()], true);
|
||||
monitor_.Stop("PredictRaw");
|
||||
TrainingObserver::Instance().Observe(preds_[train], "Predictions");
|
||||
TrainingObserver::Instance().Observe(preds_[train.get()], "Predictions");
|
||||
|
||||
monitor_.Start("GetGradient");
|
||||
obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_);
|
||||
obj_->GetGradient(preds_[train.get()], train->Info(), iter, &gpair_);
|
||||
monitor_.Stop("GetGradient");
|
||||
TrainingObserver::Instance().Observe(gpair_, "Gradients");
|
||||
|
||||
gbm_->DoBoost(train, &gpair_, obj_.get());
|
||||
gbm_->DoBoost(train.get(), &gpair_, obj_.get());
|
||||
monitor_.Stop("UpdateOneIter");
|
||||
}
|
||||
|
||||
void BoostOneIter(int iter, DMatrix* train,
|
||||
void BoostOneIter(int iter, std::shared_ptr<DMatrix> train,
|
||||
HostDeviceVector<GradientPair>* in_gpair) override {
|
||||
monitor_.Start("BoostOneIter");
|
||||
this->Configure();
|
||||
@@ -723,13 +723,13 @@ class LearnerImpl : public Learner {
|
||||
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->CheckDataSplitMode();
|
||||
this->ValidateDMatrix(train);
|
||||
this->ValidateDMatrix(train.get());
|
||||
|
||||
gbm_->DoBoost(train, in_gpair);
|
||||
gbm_->DoBoost(train.get(), in_gpair);
|
||||
monitor_.Stop("BoostOneIter");
|
||||
}
|
||||
|
||||
std::string EvalOneIter(int iter, const std::vector<DMatrix*>& data_sets,
|
||||
std::string EvalOneIter(int iter, const std::vector<std::shared_ptr<DMatrix>>& data_sets,
|
||||
const std::vector<std::string>& data_names) override {
|
||||
monitor_.Start("EvalOneIter");
|
||||
this->Configure();
|
||||
@@ -741,9 +741,9 @@ class LearnerImpl : public Learner {
|
||||
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
|
||||
}
|
||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||
DMatrix * dmat = data_sets[i];
|
||||
DMatrix * dmat = data_sets[i].get();
|
||||
this->ValidateDMatrix(dmat);
|
||||
this->PredictRaw(data_sets[i], &preds_[dmat], false);
|
||||
this->PredictRaw(dmat, &preds_[dmat], false);
|
||||
obj_->EvalTransform(&preds_[dmat]);
|
||||
for (auto& ev : metrics_) {
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||
@@ -805,7 +805,7 @@ class LearnerImpl : public Learner {
|
||||
return generic_parameters_;
|
||||
}
|
||||
|
||||
void Predict(DMatrix* data, bool output_margin,
|
||||
void Predict(std::shared_ptr<DMatrix> data, bool output_margin,
|
||||
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
|
||||
bool training,
|
||||
bool pred_leaf, bool pred_contribs, bool approx_contribs,
|
||||
@@ -816,14 +816,14 @@ class LearnerImpl : public Learner {
|
||||
this->Configure();
|
||||
CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
|
||||
if (pred_contribs) {
|
||||
gbm_->PredictContribution(data, &out_preds->HostVector(), ntree_limit, approx_contribs);
|
||||
gbm_->PredictContribution(data.get(), &out_preds->HostVector(), ntree_limit, approx_contribs);
|
||||
} else if (pred_interactions) {
|
||||
gbm_->PredictInteractionContributions(data, &out_preds->HostVector(), ntree_limit,
|
||||
gbm_->PredictInteractionContributions(data.get(), &out_preds->HostVector(), ntree_limit,
|
||||
approx_contribs);
|
||||
} else if (pred_leaf) {
|
||||
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
|
||||
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
|
||||
} else {
|
||||
this->PredictRaw(data, out_preds, training, ntree_limit);
|
||||
this->PredictRaw(data.get(), out_preds, training, ntree_limit);
|
||||
if (!output_margin) {
|
||||
obj_->PredTransform(out_preds);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user