Pass shared pointer instead of raw pointer to Learner. (#5302)
Extracted from https://github.com/dmlc/xgboost/pull/5220 .
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2014-2019 by Contributors
|
||||
// Copyright (c) 2014-2020 by Contributors
|
||||
#include <dmlc/thread_local.h>
|
||||
#include <rabit/rabit.h>
|
||||
#include <rabit/c_api.h>
|
||||
@@ -498,7 +498,7 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
|
||||
auto *dtr =
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||
|
||||
bst->UpdateOneIter(iter, dtr->get());
|
||||
bst->UpdateOneIter(iter, *dtr);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -519,7 +519,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
|
||||
tmp_gpair_h[i] = GradientPair(grad[i], hess[i]);
|
||||
}
|
||||
|
||||
bst->BoostOneIter(0, dtr->get(), &tmp_gpair);
|
||||
bst->BoostOneIter(0, *dtr, &tmp_gpair);
|
||||
API_END();
|
||||
}
|
||||
|
||||
@@ -533,11 +533,11 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Learner*>(handle);
|
||||
std::vector<DMatrix*> data_sets;
|
||||
std::vector<std::shared_ptr<DMatrix>> data_sets;
|
||||
std::vector<std::string> data_names;
|
||||
|
||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||
data_sets.push_back(static_cast<std::shared_ptr<DMatrix>*>(dmats[i])->get());
|
||||
data_sets.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
|
||||
data_names.emplace_back(evnames[i]);
|
||||
}
|
||||
|
||||
@@ -560,7 +560,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
auto *bst = static_cast<Learner*>(handle);
|
||||
HostDeviceVector<bst_float> tmp_preds;
|
||||
bst->Predict(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
|
||||
*static_cast<std::shared_ptr<DMatrix>*>(dmat),
|
||||
(option_mask & 1) != 0,
|
||||
&tmp_preds, ntree_limit,
|
||||
static_cast<bool>(training),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2019 by Contributors
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* \file cli_main.cc
|
||||
* \brief The command line interface program of xgboost.
|
||||
* This file is not included in dynamic library.
|
||||
@@ -165,7 +165,7 @@ void CLITrain(const CLIParam& param) {
|
||||
param.dsplit == 2));
|
||||
std::vector<std::shared_ptr<DMatrix> > deval;
|
||||
std::vector<std::shared_ptr<DMatrix> > cache_mats;
|
||||
std::vector<DMatrix*> eval_datasets;
|
||||
std::vector<std::shared_ptr<DMatrix>> eval_datasets;
|
||||
cache_mats.push_back(dtrain);
|
||||
for (size_t i = 0; i < param.eval_data_names.size(); ++i) {
|
||||
deval.emplace_back(
|
||||
@@ -173,12 +173,12 @@ void CLITrain(const CLIParam& param) {
|
||||
param.eval_data_paths[i],
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
param.dsplit == 2)));
|
||||
eval_datasets.push_back(deval.back().get());
|
||||
eval_datasets.push_back(deval.back());
|
||||
cache_mats.push_back(deval.back());
|
||||
}
|
||||
std::vector<std::string> eval_data_names = param.eval_data_names;
|
||||
if (param.eval_train) {
|
||||
eval_datasets.push_back(dtrain.get());
|
||||
eval_datasets.push_back(dtrain);
|
||||
eval_data_names.emplace_back("train");
|
||||
}
|
||||
// initialize the learner.
|
||||
@@ -203,7 +203,7 @@ void CLITrain(const CLIParam& param) {
|
||||
double elapsed = dmlc::GetTime() - start;
|
||||
if (version % 2 == 0) {
|
||||
LOG(INFO) << "boosting round " << i << ", " << elapsed << " sec elapsed";
|
||||
learner->UpdateOneIter(i, dtrain.get());
|
||||
learner->UpdateOneIter(i, dtrain);
|
||||
if (learner->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(learner.get());
|
||||
} else {
|
||||
@@ -305,7 +305,7 @@ void CLIPredict(const CLIParam& param) {
|
||||
CHECK_NE(param.test_path, "NULL")
|
||||
<< "Test dataset parameter test:data must be specified.";
|
||||
// load data
|
||||
std::unique_ptr<DMatrix> dtest(
|
||||
std::shared_ptr<DMatrix> dtest(
|
||||
DMatrix::Load(
|
||||
param.test_path,
|
||||
ConsoleLogger::GlobalVerbosity() > ConsoleLogger::DefaultVerbosity(),
|
||||
@@ -321,7 +321,7 @@ void CLIPredict(const CLIParam& param) {
|
||||
|
||||
LOG(INFO) << "start prediction...";
|
||||
HostDeviceVector<bst_float> preds;
|
||||
learner->Predict(dtest.get(), param.pred_margin, &preds, param.ntree_limit);
|
||||
learner->Predict(dtest, param.pred_margin, &preds, param.ntree_limit);
|
||||
LOG(CONSOLE) << "writing prediction to " << param.name_pred;
|
||||
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2019 XGBoost contributors
|
||||
* Copyright 2017-2020 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <thrust/device_ptr.h>
|
||||
@@ -9,7 +9,6 @@
|
||||
#include <thrust/system_error.h>
|
||||
#include <thrust/logical.h>
|
||||
|
||||
#include <omp.h>
|
||||
#include <rabit/rabit.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/util_allocator.cuh>
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
/*!
|
||||
* Copyright 2019 XGBoost contributors
|
||||
* Copyright 2019-2020 XGBoost contributors
|
||||
* \file observer.h
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_OBSERVER_H_
|
||||
#define XGBOOST_COMMON_OBSERVER_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -63,7 +64,8 @@ class TrainingObserver {
|
||||
}
|
||||
/*\brief Observe data hosted by `std::vector'. */
|
||||
template <typename T>
|
||||
void Observe(std::vector<T> const& h_vec, std::string name) const {
|
||||
void Observe(std::vector<T> const& h_vec, std::string name,
|
||||
size_t n = std::numeric_limits<std::size_t>::max()) const {
|
||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||
OBSERVER_PRINT << "Procedure: " << name << OBSERVER_ENDL;
|
||||
|
||||
@@ -72,20 +74,25 @@ class TrainingObserver {
|
||||
if (i % 8 == 0) {
|
||||
OBSERVER_PRINT << OBSERVER_NEWLINE;
|
||||
}
|
||||
if ((i + 1) == n) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
OBSERVER_PRINT << OBSERVER_ENDL;
|
||||
}
|
||||
/*\brief Observe data hosted by `HostDeviceVector'. */
|
||||
template <typename T>
|
||||
void Observe(HostDeviceVector<T> const& vec, std::string name) const {
|
||||
void Observe(HostDeviceVector<T> const& vec, std::string name,
|
||||
size_t n = std::numeric_limits<std::size_t>::max()) const {
|
||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||
auto const& h_vec = vec.HostVector();
|
||||
this->Observe(h_vec, name);
|
||||
this->Observe(h_vec, name, n);
|
||||
}
|
||||
template <typename T>
|
||||
void Observe(HostDeviceVector<T>* vec, std::string name) const {
|
||||
void Observe(HostDeviceVector<T>* vec, std::string name,
|
||||
size_t n = std::numeric_limits<std::size_t>::max()) const {
|
||||
if (XGBOOST_EXPECT(!observe_, true)) { return; }
|
||||
this->Observe(*vec, name);
|
||||
this->Observe(*vec, name, n);
|
||||
}
|
||||
|
||||
/*\brief Observe objects with `XGBoostParamer' type. */
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014 by Contributors
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* \file gblinear.cc
|
||||
* \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
|
||||
* the update rule is parallel coordinate descent (shotgun)
|
||||
@@ -239,7 +239,7 @@ class GBLinear : public GradientBooster {
|
||||
void PredictBatchInternal(DMatrix *p_fmat,
|
||||
std::vector<bst_float> *out_preds) {
|
||||
monitor_.Start("PredictBatchInternal");
|
||||
model_.LazyInitModel();
|
||||
model_.LazyInitModel();
|
||||
std::vector<bst_float> &preds = *out_preds;
|
||||
const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector();
|
||||
// start collecting the prediction
|
||||
@@ -250,6 +250,9 @@ class GBLinear : public GradientBooster {
|
||||
// k is number of group
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<omp_ulong>(batch.Size());
|
||||
if (base_margin.size() != 0) {
|
||||
CHECK_EQ(base_margin.size(), nsize * ngroup);
|
||||
}
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong i = 0; i < nsize; ++i) {
|
||||
const size_t ridx = batch.base_rowid + i;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2019 by Contributors
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* \file learner.cc
|
||||
* \brief Implementation of learning algorithm.
|
||||
* \author Tianqi Chen
|
||||
@@ -691,7 +691,7 @@ class LearnerImpl : public Learner {
|
||||
return gbm_->DumpModel(fmap, with_stats, format);
|
||||
}
|
||||
|
||||
void UpdateOneIter(int iter, DMatrix* train) override {
|
||||
void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
|
||||
monitor_.Start("UpdateOneIter");
|
||||
TrainingObserver::Instance().Update(iter);
|
||||
this->Configure();
|
||||
@@ -699,23 +699,23 @@ class LearnerImpl : public Learner {
|
||||
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->CheckDataSplitMode();
|
||||
this->ValidateDMatrix(train);
|
||||
this->ValidateDMatrix(train.get());
|
||||
|
||||
monitor_.Start("PredictRaw");
|
||||
this->PredictRaw(train, &preds_[train], true);
|
||||
this->PredictRaw(train.get(), &preds_[train.get()], true);
|
||||
monitor_.Stop("PredictRaw");
|
||||
TrainingObserver::Instance().Observe(preds_[train], "Predictions");
|
||||
TrainingObserver::Instance().Observe(preds_[train.get()], "Predictions");
|
||||
|
||||
monitor_.Start("GetGradient");
|
||||
obj_->GetGradient(preds_[train], train->Info(), iter, &gpair_);
|
||||
obj_->GetGradient(preds_[train.get()], train->Info(), iter, &gpair_);
|
||||
monitor_.Stop("GetGradient");
|
||||
TrainingObserver::Instance().Observe(gpair_, "Gradients");
|
||||
|
||||
gbm_->DoBoost(train, &gpair_, obj_.get());
|
||||
gbm_->DoBoost(train.get(), &gpair_, obj_.get());
|
||||
monitor_.Stop("UpdateOneIter");
|
||||
}
|
||||
|
||||
void BoostOneIter(int iter, DMatrix* train,
|
||||
void BoostOneIter(int iter, std::shared_ptr<DMatrix> train,
|
||||
HostDeviceVector<GradientPair>* in_gpair) override {
|
||||
monitor_.Start("BoostOneIter");
|
||||
this->Configure();
|
||||
@@ -723,13 +723,13 @@ class LearnerImpl : public Learner {
|
||||
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->CheckDataSplitMode();
|
||||
this->ValidateDMatrix(train);
|
||||
this->ValidateDMatrix(train.get());
|
||||
|
||||
gbm_->DoBoost(train, in_gpair);
|
||||
gbm_->DoBoost(train.get(), in_gpair);
|
||||
monitor_.Stop("BoostOneIter");
|
||||
}
|
||||
|
||||
std::string EvalOneIter(int iter, const std::vector<DMatrix*>& data_sets,
|
||||
std::string EvalOneIter(int iter, const std::vector<std::shared_ptr<DMatrix>>& data_sets,
|
||||
const std::vector<std::string>& data_names) override {
|
||||
monitor_.Start("EvalOneIter");
|
||||
this->Configure();
|
||||
@@ -741,9 +741,9 @@ class LearnerImpl : public Learner {
|
||||
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
|
||||
}
|
||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||
DMatrix * dmat = data_sets[i];
|
||||
DMatrix * dmat = data_sets[i].get();
|
||||
this->ValidateDMatrix(dmat);
|
||||
this->PredictRaw(data_sets[i], &preds_[dmat], false);
|
||||
this->PredictRaw(dmat, &preds_[dmat], false);
|
||||
obj_->EvalTransform(&preds_[dmat]);
|
||||
for (auto& ev : metrics_) {
|
||||
os << '\t' << data_names[i] << '-' << ev->Name() << ':'
|
||||
@@ -805,7 +805,7 @@ class LearnerImpl : public Learner {
|
||||
return generic_parameters_;
|
||||
}
|
||||
|
||||
void Predict(DMatrix* data, bool output_margin,
|
||||
void Predict(std::shared_ptr<DMatrix> data, bool output_margin,
|
||||
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
|
||||
bool training,
|
||||
bool pred_leaf, bool pred_contribs, bool approx_contribs,
|
||||
@@ -816,14 +816,14 @@ class LearnerImpl : public Learner {
|
||||
this->Configure();
|
||||
CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
|
||||
if (pred_contribs) {
|
||||
gbm_->PredictContribution(data, &out_preds->HostVector(), ntree_limit, approx_contribs);
|
||||
gbm_->PredictContribution(data.get(), &out_preds->HostVector(), ntree_limit, approx_contribs);
|
||||
} else if (pred_interactions) {
|
||||
gbm_->PredictInteractionContributions(data, &out_preds->HostVector(), ntree_limit,
|
||||
gbm_->PredictInteractionContributions(data.get(), &out_preds->HostVector(), ntree_limit,
|
||||
approx_contribs);
|
||||
} else if (pred_leaf) {
|
||||
gbm_->PredictLeaf(data, &out_preds->HostVector(), ntree_limit);
|
||||
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);
|
||||
} else {
|
||||
this->PredictRaw(data, out_preds, training, ntree_limit);
|
||||
this->PredictRaw(data.get(), out_preds, training, ntree_limit);
|
||||
if (!output_margin) {
|
||||
obj_->PredTransform(out_preds);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user