reset learner
This commit is contained in:
parent
1931a70598
commit
1c666db349
@ -61,8 +61,6 @@
|
|||||||
#include "xgboost/predictor.h" // for PredictionContainer, PredictionCacheEntry
|
#include "xgboost/predictor.h" // for PredictionContainer, PredictionCacheEntry
|
||||||
#include "xgboost/string_view.h" // for operator<<, StringView
|
#include "xgboost/string_view.h" // for operator<<, StringView
|
||||||
#include "xgboost/task.h" // for ObjInfo
|
#include "xgboost/task.h" // for ObjInfo
|
||||||
#include <iostream>
|
|
||||||
#include <exception>
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
const char* kMaxDeltaStepDefaultValue = "0.7";
|
const char* kMaxDeltaStepDefaultValue = "0.7";
|
||||||
@ -1264,45 +1262,34 @@ class LearnerImpl : public LearnerIO {
|
|||||||
return out_impl;
|
return out_impl;
|
||||||
}
|
}
|
||||||
|
|
||||||
void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
|
void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
|
||||||
monitor_.Start("UpdateOneIter");
|
monitor_.Start("UpdateOneIter");
|
||||||
TrainingObserver::Instance().Update(iter);
|
TrainingObserver::Instance().Update(iter);
|
||||||
|
this->Configure();
|
||||||
|
this->InitBaseScore(train.get());
|
||||||
|
|
||||||
this->Configure();
|
if (ctx_.seed_per_iteration) {
|
||||||
|
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
|
||||||
|
}
|
||||||
|
|
||||||
this->InitBaseScore(train.get());
|
this->ValidateDMatrix(train.get(), true);
|
||||||
|
|
||||||
if (ctx_.seed_per_iteration) {
|
auto& predt = prediction_container_.Cache(train, ctx_.Device());
|
||||||
common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
|
|
||||||
}
|
|
||||||
|
|
||||||
this->ValidateDMatrix(train.get(), true);
|
monitor_.Start("PredictRaw");
|
||||||
|
this->PredictRaw(train.get(), &predt, true, 0, 0);
|
||||||
|
TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
|
||||||
|
monitor_.Stop("PredictRaw");
|
||||||
|
|
||||||
auto& predt = prediction_container_.Cache(train, ctx_.Device());
|
monitor_.Start("GetGradient");
|
||||||
|
|
||||||
monitor_.Start("PredictRaw");
|
|
||||||
this->PredictRaw(train.get(), &predt, true, 0, 0);
|
|
||||||
TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
|
|
||||||
monitor_.Stop("PredictRaw");
|
|
||||||
|
|
||||||
monitor_.Start("GetGradient");
|
|
||||||
try {
|
|
||||||
GetGradient(predt.predictions, train->Info(), iter, &gpair_);
|
GetGradient(predt.predictions, train->Info(), iter, &gpair_);
|
||||||
} catch (const std::exception& e) {
|
monitor_.Stop("GetGradient");
|
||||||
throw;
|
TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients");
|
||||||
}
|
|
||||||
monitor_.Stop("GetGradient");
|
|
||||||
TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients");
|
|
||||||
|
|
||||||
try {
|
|
||||||
gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get());
|
gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get());
|
||||||
} catch (const std::exception& e) {
|
monitor_.Stop("UpdateOneIter");
|
||||||
throw;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
monitor_.Stop("UpdateOneIter");
|
|
||||||
}
|
|
||||||
|
|
||||||
void BoostOneIter(int iter, std::shared_ptr<DMatrix> train,
|
void BoostOneIter(int iter, std::shared_ptr<DMatrix> train,
|
||||||
linalg::Matrix<GradientPair>* in_gpair) override {
|
linalg::Matrix<GradientPair>* in_gpair) override {
|
||||||
monitor_.Start("BoostOneIter");
|
monitor_.Start("BoostOneIter");
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user