From 1c666db34962694b57b8c48a9f4f1a8a858062c9 Mon Sep 17 00:00:00 2001
From: Hendrik Groove
Date: Mon, 21 Oct 2024 23:49:22 +0200
Subject: [PATCH] reset learner

---
 src/learner.cc | 51 +++++++++++++++++++-------------------------------
 1 file changed, 19 insertions(+), 32 deletions(-)

diff --git a/src/learner.cc b/src/learner.cc
index 02b3bc569..a56829aad 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -61,8 +61,6 @@
 #include "xgboost/predictor.h"  // for PredictionContainer, PredictionCacheEntry
 #include "xgboost/string_view.h"  // for operator<<, StringView
 #include "xgboost/task.h"  // for ObjInfo
-#include
-#include
 
 namespace {
 const char* kMaxDeltaStepDefaultValue = "0.7";
@@ -1264,44 +1262,33 @@ class LearnerImpl : public LearnerIO {
     return out_impl;
   }
 
-void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
-  monitor_.Start("UpdateOneIter");
-  TrainingObserver::Instance().Update(iter);
-
-  this->Configure();
-
-  this->InitBaseScore(train.get());
+  void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
+    monitor_.Start("UpdateOneIter");
+    TrainingObserver::Instance().Update(iter);
+    this->Configure();
+    this->InitBaseScore(train.get());
 
-  if (ctx_.seed_per_iteration) {
-    common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
-  }
+    if (ctx_.seed_per_iteration) {
+      common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
+    }
 
-  this->ValidateDMatrix(train.get(), true);
+    this->ValidateDMatrix(train.get(), true);
 
-  auto& predt = prediction_container_.Cache(train, ctx_.Device());
+    auto& predt = prediction_container_.Cache(train, ctx_.Device());
 
-  monitor_.Start("PredictRaw");
-  this->PredictRaw(train.get(), &predt, true, 0, 0);
-  TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
-  monitor_.Stop("PredictRaw");
+    monitor_.Start("PredictRaw");
+    this->PredictRaw(train.get(), &predt, true, 0, 0);
+    TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
+    monitor_.Stop("PredictRaw");
 
-  monitor_.Start("GetGradient");
-  try {
+    monitor_.Start("GetGradient");
     GetGradient(predt.predictions, train->Info(), iter, &gpair_);
-  } catch (const std::exception& e) {
-    throw;
-  }
-  monitor_.Stop("GetGradient");
-  TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients");
+    monitor_.Stop("GetGradient");
+    TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients");
 
-  try {
     gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get());
-  } catch (const std::exception& e) {
-    throw;
+    monitor_.Stop("UpdateOneIter");
   }
-
-  monitor_.Stop("UpdateOneIter");
-}
 
   void BoostOneIter(int iter, std::shared_ptr<DMatrix> train,
                     linalg::Matrix<GradientPair>* in_gpair) override {
@@ -1503,4 +1490,4 @@ Learner* Learner::Create(
     const std::vector<std::shared_ptr<DMatrix> >& cache_data) {
   return new LearnerImpl(cache_data);
 }
-}  // namespace xgboost
+}  // namespace xgboost
\ No newline at end of file