reset learner

Hendrik Groove 2024-10-21 23:49:22 +02:00
parent 1931a70598
commit 1c666db349


@@ -61,8 +61,6 @@
 #include "xgboost/predictor.h"    // for PredictionContainer, PredictionCacheEntry
 #include "xgboost/string_view.h"  // for operator<<, StringView
 #include "xgboost/task.h"         // for ObjInfo
-#include <iostream>
-#include <exception>
 
 namespace {
 const char* kMaxDeltaStepDefaultValue = "0.7";
@@ -1264,44 +1262,33 @@ class LearnerImpl : public LearnerIO {
     return out_impl;
   }
 
   void UpdateOneIter(int iter, std::shared_ptr<DMatrix> train) override {
     monitor_.Start("UpdateOneIter");
     TrainingObserver::Instance().Update(iter);
-    this->Configure();
     this->Configure();
     this->InitBaseScore(train.get());
 
     if (ctx_.seed_per_iteration) {
       common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
     }
 
     this->ValidateDMatrix(train.get(), true);
 
     auto& predt = prediction_container_.Cache(train, ctx_.Device());
 
     monitor_.Start("PredictRaw");
     this->PredictRaw(train.get(), &predt, true, 0, 0);
     TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
     monitor_.Stop("PredictRaw");
 
     monitor_.Start("GetGradient");
-    try {
-      GetGradient(predt.predictions, train->Info(), iter, &gpair_);
-    } catch (const std::exception& e) {
-      throw;
-    }
+    GetGradient(predt.predictions, train->Info(), iter, &gpair_);
     monitor_.Stop("GetGradient");
     TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients");
 
-    try {
-      gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get());
-    } catch (const std::exception& e) {
-      throw;
-    }
+    gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get());
     monitor_.Stop("UpdateOneIter");
   }
 
   void BoostOneIter(int iter, std::shared_ptr<DMatrix> train,
                     linalg::Matrix<GradientPair>* in_gpair) override {
@@ -1503,4 +1490,4 @@ Learner* Learner::Create(
     const std::vector<std::shared_ptr<DMatrix> >& cache_data) {
   return new LearnerImpl(cache_data);
 }
 }  // namespace xgboost
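
Aside on the removed error handling: both deleted try/catch blocks caught const std::exception& and immediately rethrew with "throw;", which re-raises the in-flight exception unchanged, so they were behavioral no-ops (plausibly leftovers from a debugging session, which would also explain the now-unused <iostream> and <exception> includes). A minimal standalone sketch of the equivalence; this is illustrative code, not XGBoost code:

    #include <exception>
    #include <iostream>
    #include <stdexcept>

    void Compute() { throw std::runtime_error("bad input"); }

    // The removed pattern: catch, then immediately rethrow. "throw;"
    // re-raises the current exception object, so the caller observes
    // exactly what it would without the try/catch.
    void WithRethrow() {
      try {
        Compute();
      } catch (const std::exception& e) {
        throw;  // no logging, no translation: dead weight
      }
    }

    // What the commit reverts to: no handler at all.
    void WithoutRethrow() { Compute(); }

    int main() {
      for (auto fn : {WithRethrow, WithoutRethrow}) {
        try {
          fn();
        } catch (const std::exception& e) {
          std::cout << "caught: " << e.what() << '\n';  // same output both times
        }
      }
    }

Such a handler only earns its keep if the body does work before rethrowing, e.g. logging e.what() or wrapping via std::throw_with_nested; with nothing but "throw;" in the body, deleting it is a pure cleanup.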