Configuration for init estimation. (#8343)

* Configuration for init estimation.

* Check whether the model needs configuration based on the const method
`ModelFitted()` instead of a mutable state flag.
* Add the parameter `boost_from_average` to record whether the user has specified `base_score`.
* Add tests.
Jiaming Yuan 2022-10-18 01:52:24 +08:00 committed by GitHub
parent 2176e511fc
commit 031d66ec27
10 changed files with 247 additions and 111 deletions

View File

@@ -370,9 +370,11 @@ Specify the learning task and the corresponding learning objective.
 - ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed <https://en.wikipedia.org/wiki/Gamma_distribution#Occurrence_and_applications>`_.
 - ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed <https://en.wikipedia.org/wiki/Tweedie_distribution#Occurrence_and_applications>`_.

-* ``base_score`` [default=0.5]
+* ``base_score``

   - The initial prediction score of all instances, global bias
+  - The parameter is automatically estimated for selected objectives before training. To
+    disable the estimation, specify a real number argument.
   - For sufficient number of iterations, changing this value will not have too much effect.

 * ``eval_metric`` [default according to objective]
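A minimal sketch of the documented behaviour through the C API (these C API functions exist in XGBoost; the training file name is a placeholder, and this snippet is not part of the diff):

#include <xgboost/c_api.h>

int main() {
  DMatrixHandle dtrain;
  XGDMatrixCreateFromFile("train.libsvm", 1, &dtrain);  // placeholder input file

  BoosterHandle booster;
  XGBoosterCreate(&dtrain, 1, &booster);
  XGBoosterSetParam(booster, "objective", "reg:absoluteerror");
  // Leaving `base_score` unset lets XGBoost estimate it from the training data.
  // Uncommenting the next line pins the intercept and disables the estimation:
  // XGBoosterSetParam(booster, "base_score", "0.5");
  XGBoosterUpdateOneIter(booster, 0, dtrain);

  XGBoosterFree(booster);
  XGDMatrixFree(dtrain);
  return 0;
}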

View File

@@ -75,6 +75,11 @@ class GradientBooster : public Model, public Configurable {
   /*! \brief Return number of boosted rounds.
    */
   virtual int32_t BoostedRounds() const = 0;
+  /**
+   * \brief Whether the model has already been trained. When tree booster is chosen, then
+   * returns true when there are existing trees.
+   */
+  virtual bool ModelFitted() const = 0;
   /*!
    * \brief perform update to the model(boosting)
    * \param p_fmat feature matrix that provide access to features
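A self-contained sketch of the contract this new pure virtual introduces: fitted state is derived from the model itself via a const query, and callers branch on it instead of mutating a flag. Toy types only, not the real GradientBooster:

#include <iostream>
#include <vector>

struct Booster {
  virtual ~Booster() = default;
  virtual bool ModelFitted() const = 0;  // mirrors GradientBooster::ModelFitted
};

struct TreeBooster : Booster {
  std::vector<int> trees;  // stand-in for model_.trees
  bool ModelFitted() const override { return !trees.empty(); }
};

int main() {
  TreeBooster gbm;
  std::cout << "needs estimation: " << !gbm.ModelFitted() << '\n';  // 1
  gbm.trees.push_back(0);                                           // "train" one round
  std::cout << "needs estimation: " << !gbm.ModelFitted() << '\n';  // 0
}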

View File

@@ -328,7 +328,7 @@ struct LearnerModelParam {
   void Copy(LearnerModelParam const& that);

   /* \brief Whether this parameter is initialized with LearnerModelParamLegacy. */
-  bool Initialized() const { return num_feature != 0; }
+  bool Initialized() const { return num_feature != 0 && num_output_group != 0; }
 };
 }  // namespace xgboost

View File

@@ -162,6 +162,10 @@ class HostDeviceVectorImpl {
     if (device_ >= 0) {
       LazySyncHost(GPUAccess::kNone);
     }
+
+    if (device_ >= 0 && device >= 0) {
+      CHECK_EQ(device_, device) << "New device ordinal is different from previous one.";
+    }
     device_ = device;
     if (device_ >= 0) {
       LazyResizeDevice(data_h_.size());
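A toy model of the new guard, assuming the same semantics: a vector may move between host (ordinal -1) and a device, but never silently between two different devices. This is a sketch, not HostDeviceVectorImpl:

#include <cassert>

struct DeviceVec {
  int device{-1};  // -1 == host
  void SetDevice(int new_device) {
    if (device >= 0 && new_device >= 0) {
      assert(device == new_device && "New device ordinal is different from previous one.");
    }
    device = new_device;
  }
};

int main() {
  DeviceVec v;
  v.SetDevice(0);   // host -> GPU 0: fine
  v.SetDevice(-1);  // back to host: fine
  v.SetDevice(1);   // host -> GPU 1: fine
  // v.SetDevice(0); // GPU 1 -> GPU 0 would trip the assertion
}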

View File

@@ -3,8 +3,8 @@
  */
 #ifndef XGBOOST_COMMON_LINALG_OP_H_
 #define XGBOOST_COMMON_LINALG_OP_H_
-#include <type_traits>
 #include <cstdint>  // std::int32_t
+#include <type_traits>

 #include "common.h"
 #include "threading_utils.h"
@@ -43,12 +43,12 @@ void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& fn) {
 #if !defined(XGBOOST_USE_CUDA)
 template <typename T, int32_t D, typename Fn>
-void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, void* s = nullptr) {
+void ElementWiseKernelDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr) {
   common::AssertGPUSupport();
 }

 template <typename T, int32_t D, typename Fn>
-void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, void* s = nullptr) {
+void ElementWiseTransformDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr) {
   common::AssertGPUSupport();
 }

View File

@@ -95,6 +95,8 @@ class GBLinear : public GradientBooster {
     return model_.num_boosted_rounds;
   }

+  bool ModelFitted() const override { return BoostedRounds() != 0; }
+
   void Load(dmlc::Stream* fi) override {
     model_.Load(fi);
   }

View File

@@ -252,6 +252,10 @@ class GBTree : public GradientBooster {
     return model_.trees.size() / this->LayerTrees();
   }

+  bool ModelFitted() const override {
+    return !model_.trees.empty() || !model_.trees_to_update.empty();
+  }
+
   void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *out_preds,
                     bool training, unsigned layer_begin, unsigned layer_end) override;

View File

@@ -12,6 +12,7 @@
 #include <dmlc/thread_local.h>

 #include <algorithm>
+#include <array>
 #include <atomic>
 #include <iomanip>
 #include <limits>  // std::numeric_limits
@@ -27,7 +28,6 @@
 #include "common/charconv.h"
 #include "common/common.h"
 #include "common/io.h"
-#include "common/linalg_op.h"
 #include "common/observer.h"
 #include "common/random.h"
 #include "common/threading_utils.h"
@@ -64,6 +64,15 @@ DECLARE_FIELD_ENUM_CLASS(xgboost::DataSplitMode);
 namespace xgboost {
 Learner::~Learner() = default;
+namespace {
+StringView ModelNotFitted() { return "Model is not yet initialized (not fitted)."; }
+
+template <typename T>
+T& UsePtr(T& ptr) {  // NOLINT
+  CHECK(ptr);
+  return ptr;
+}
+}  // anonymous namespace
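A standalone sketch of why UsePtr exists: the smart pointer is checked before every dereference, so a missing component fails loudly instead of as undefined behaviour. assert stands in for CHECK, and Objective is a toy type:

#include <cassert>
#include <memory>

template <typename T>
T& UsePtr(T& ptr) {
  assert(ptr);  // stands in for CHECK(ptr)
  return ptr;
}

struct Objective {
  int Task() const { return 0; }
};

int main() {
  std::unique_ptr<Objective> obj = std::make_unique<Objective>();
  int task = UsePtr(obj)->Task();  // obj is validated before use
  (void)task;
}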
 /*! \brief training parameter for regression
  *
@@ -75,20 +84,28 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
   /* \brief global bias */
   bst_float base_score;
   /* \brief number of features  */
-  uint32_t num_feature;
+  bst_feature_t num_feature;
   /* \brief number of classes, if it is multi-class classification */
-  int32_t num_class;
+  std::int32_t num_class;
   /*! \brief Model contain additional properties */
   int32_t contain_extra_attrs;
   /*! \brief Model contain eval metrics */
   int32_t contain_eval_metrics;
   /*! \brief the version of XGBoost. */
-  uint32_t major_version;
-  uint32_t minor_version;
+  std::uint32_t major_version;
+  std::uint32_t minor_version;

   uint32_t num_target{1};
-  int32_t base_score_estimated{0};
+  /**
+   * \brief Whether we should calculate the base score from training data.
+   *
+   * This is a private parameter as we can't expose it as boolean due to binary model
+   * format. Exposing it as integer creates inconsistency with other parameters.
+   *
+   * Automatically disabled when base_score is specified by user. int32 is used instead
+   * of bool for the ease of serialization.
+   */
+  std::int32_t boost_from_average{true};
   /*! \brief reserved field */
   int reserved[25];
   /*! \brief constructor */
@@ -98,14 +115,14 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     num_target = 1;
     major_version = std::get<0>(Version::Self());
     minor_version = std::get<1>(Version::Self());
-    base_score_estimated = 0;
+    boost_from_average = true;
     static_assert(sizeof(LearnerModelParamLegacy) == 136,
                   "Do not change the size of this struct, as it will break binary IO.");
   }
   // Skip other legacy fields.
   Json ToJson() const {
-    Object obj;
+    Json obj{Object{}};
     char floats[NumericLimits<float>::kToCharsSize];
     auto ret = to_chars(floats, floats + NumericLimits<float>::kToCharsSize, base_score);
     CHECK(ret.ec == std::errc{});
@@ -120,15 +137,19 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     ret = to_chars(integers, integers + NumericLimits<int64_t>::kToCharsSize,
                    static_cast<int64_t>(num_class));
     CHECK(ret.ec == std::errc());
-    obj["num_class"] =
-        std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};
+    obj["num_class"] = std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};

     ret = to_chars(integers, integers + NumericLimits<int64_t>::kToCharsSize,
                    static_cast<int64_t>(num_target));
     obj["num_target"] =
         std::string{integers, static_cast<size_t>(std::distance(integers, ret.ptr))};
-    return Json(std::move(obj));
+
+    ret = to_chars(integers, integers + NumericLimits<std::int64_t>::kToCharsSize,
+                   static_cast<std::int64_t>(boost_from_average));
+    obj["boost_from_average"] =
+        std::string{integers, static_cast<std::size_t>(std::distance(integers, ret.ptr))};
+
+    return obj;
   }
   void FromJson(Json const& obj) {
     auto const& j_param = get<Object const>(obj);
@@ -139,13 +160,15 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     if (n_targets_it != j_param.cend()) {
       m["num_target"] = get<String const>(n_targets_it->second);
     }
+    auto bse_it = j_param.find("boost_from_average");
+    if (bse_it != j_param.cend()) {
+      m["boost_from_average"] = get<String const>(bse_it->second);
+    }

     this->Init(m);
     std::string str = get<String const>(j_param.at("base_score"));
     from_chars(str.c_str(), str.c_str() + str.size(), base_score);
-    // It can only be estimated during the first training, we consider it estimated afterward
-    base_score_estimated = 1;
   }

   LearnerModelParamLegacy ByteSwap() const {
@@ -158,7 +181,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
     dmlc::ByteSwap(&x.major_version, sizeof(x.major_version), 1);
     dmlc::ByteSwap(&x.minor_version, sizeof(x.minor_version), 1);
     dmlc::ByteSwap(&x.num_target, sizeof(x.num_target), 1);
-    dmlc::ByteSwap(&x.base_score_estimated, sizeof(x.base_score_estimated), 1);
+    dmlc::ByteSwap(&x.boost_from_average, sizeof(x.boost_from_average), 1);
     dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
     return x;
   }
@@ -166,14 +189,13 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
   template <typename Container>
   Args UpdateAllowUnknown(Container const& kwargs) {
     // Detect whether user has made their own base score.
-    if (std::find_if(kwargs.cbegin(), kwargs.cend(),
-                     [](auto const& kv) { return kv.first == "base_score"; }) != kwargs.cend()) {
-      base_score_estimated = true;
-    }
-    if (std::find_if(kwargs.cbegin(), kwargs.cend(), [](auto const& kv) {
-          return kv.first == "base_score_estimated";
-        }) != kwargs.cend()) {
-      LOG(FATAL) << "`base_score_estimated` cannot be specified as hyper-parameter.";
-    }
+    auto find_key = [&kwargs](char const* key) {
+      return std::find_if(kwargs.cbegin(), kwargs.cend(),
+                          [key](auto const& kv) { return kv.first == key; });
+    };
+    auto it = find_key("base_score");
+    if (it != kwargs.cend()) {
+      boost_from_average = false;
+    }
     return dmlc::Parameter<LearnerModelParamLegacy>::UpdateAllowUnknown(kwargs);
   }
@@ -195,7 +217,9 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
         .set_default(1)
         .set_lower_bound(1)
         .describe("Number of target for multi-target regression.");
-    DMLC_DECLARE_FIELD(base_score_estimated).set_default(0);
+    DMLC_DECLARE_FIELD(boost_from_average)
+        .set_default(true)
+        .describe("Whether we should calculate the base score from training data.");
   }
 };
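A sketch of the binary-compatibility constraint the new field's comment refers to, mirroring the fields shown in the hunks above: every member is a fixed-width 4-byte value, so swapping one for bool (typically 1 byte) would change the struct size and break old binary model files. Assumes the usual ABIs where float and the fixed-width integers are 4 bytes:

#include <cstdint>

struct LegacyParam {
  float base_score;
  std::uint32_t num_feature;
  std::int32_t num_class;
  std::int32_t contain_extra_attrs;
  std::int32_t contain_eval_metrics;
  std::uint32_t major_version;
  std::uint32_t minor_version;
  std::uint32_t num_target;
  std::int32_t boost_from_average;  // bool would shrink this slot and shift `reserved`
  int reserved[25];
};
// 9 * 4 bytes + 25 * 4 bytes == 136, matching the static_assert in the diff.
static_assert(sizeof(LegacyParam) == 136, "binary model layout must stay fixed");

int main() { return 0; }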
@@ -224,7 +248,7 @@ LearnerModelParam::LearnerModelParam(Context const* ctx, LearnerModelParamLegacy
 linalg::TensorView<float const, 1> LearnerModelParam::BaseScore(int32_t device) const {
   // multi-class is not yet supported.
-  CHECK_EQ(base_score_.Size(), 1);
+  CHECK_EQ(base_score_.Size(), 1) << ModelNotFitted();
   if (device == Context::kCpuId) {
     // Make sure that we won't run into race condition.
     CHECK(base_score_.Data()->HostCanRead());
@@ -385,6 +409,21 @@ class LearnerConfiguration : public Learner {
   // Initial prediction.
   std::vector<std::string> metric_names_;

+  void ConfigureModelParamWithoutBaseScore() {
+    // Convert mparam to learner_model_param
+    this->ConfigureTargets();
+
+    auto task = UsePtr(obj_)->Task();
+    linalg::Tensor<float, 1> base_score({1}, Ctx()->gpu_id);
+    auto h_base_score = base_score.HostView();
+
+    // transform to margin
+    h_base_score(0) = obj_->ProbToMargin(mparam_.base_score);
+    // move it to model param, which is shared with all other components.
+    learner_model_param_ = LearnerModelParam(Ctx(), mparam_, std::move(base_score), task);
+    CHECK(learner_model_param_.Initialized());
+    CHECK_NE(learner_model_param_.BaseScore(Ctx()).Size(), 0);
+  }
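An illustrative sketch of the "transform to margin" step above: for a binary logistic objective, ProbToMargin is the logit, so a stored base_score of 0.5 becomes a raw margin of 0 that boosting starts from. Toy function, not xgboost's ObjFunction:

#include <cmath>
#include <iostream>

// Logit: maps a probability-space intercept to margin space.
double ProbToMarginLogistic(double prob) { return std::log(prob / (1.0 - prob)); }

int main() {
  double base_score = 0.5;                           // user-facing value
  double margin = ProbToMarginLogistic(base_score);  // what the trees add onto
  std::cout << "margin: " << margin << '\n';         // 0
}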
   /**
    * \brief Calculate the `base_score` based on input data.
    *
@@ -403,38 +442,24 @@ class LearnerConfiguration : public Learner {
     //  - model loaded from new binary or JSON.
     //  - model is created from scratch.
     //  - model is configured second time due to change of parameter
-    CHECK(obj_);
-    if (!mparam_.base_score_estimated) {
+    if (!learner_model_param_.Initialized()) {
+      this->ConfigureModelParamWithoutBaseScore();
+    }
+    if (mparam_.boost_from_average && !UsePtr(gbm_)->ModelFitted()) {
       if (p_fmat) {
+        auto const& info = p_fmat->Info();
+        info.Validate(Ctx()->gpu_id);
         // We estimate it from input data.
         linalg::Tensor<float, 1> base_score;
-        obj_->InitEstimation(p_fmat->Info(), &base_score);
+        UsePtr(obj_)->InitEstimation(info, &base_score);
         mparam_.base_score = base_score(0);
         CHECK(!std::isnan(mparam_.base_score));
-      } else {
-        mparam_.base_score = ObjFunction::DefaultBaseScore();
       }
-      mparam_.base_score_estimated = true;
       // Update the shared model parameter
-      this->ConfigureModelParam();
-    }
-  }
-
-  // Convert mparam to learner_model_param
-  void ConfigureModelParam() {
-    this->ConfigureTargets();
-    CHECK(obj_);
-    auto task = obj_->Task();
-    linalg::Tensor<float, 1> base_score({1}, Ctx()->gpu_id);
-    auto h_base_score = base_score.HostView();
-    // transform to margin
-    h_base_score(0) = obj_->ProbToMargin(mparam_.base_score);
-    // move it to model param, which is shared with all other components.
-    learner_model_param_ = LearnerModelParam(Ctx(), mparam_, std::move(base_score), task);
-    CHECK(learner_model_param_.Initialized());
-    CHECK_NE(learner_model_param_.BaseScore(Ctx()).Size(), 0);
+      this->ConfigureModelParamWithoutBaseScore();
+    }
+    CHECK(!std::isnan(mparam_.base_score));
+    CHECK(!std::isinf(mparam_.base_score));
   }
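A condensed, self-contained restatement of the new initialization flow, with xgboost types replaced by placeholders. The two decisions are independent: (1) build the shared model parameter once, (2) estimate the intercept only when boosting from average is on and nothing has been trained yet:

#include <iostream>
#include <optional>

struct Learner {
  bool initialized{false};       // learner_model_param_.Initialized()
  bool boost_from_average{true}; // mparam_.boost_from_average
  bool model_fitted{false};      // gbm_->ModelFitted()
  double base_score{0.5};

  void InitBaseScore(std::optional<double> data_estimate) {
    if (!initialized) {
      initialized = true;  // ConfigureModelParamWithoutBaseScore()
    }
    if (boost_from_average && !model_fitted) {
      if (data_estimate) {            // p_fmat != nullptr
        base_score = *data_estimate;  // obj_->InitEstimation(...)
        // reconfigure so the new intercept is shared with all components
      }
    }
  }
};

int main() {
  Learner l;
  l.InitBaseScore(1.3);               // first training round: estimate is used
  std::cout << l.base_score << '\n';  // 1.3
  l.model_fitted = true;
  l.InitBaseScore(9.9);               // continued training: estimation skipped
  std::cout << l.base_score << '\n';  // still 1.3
}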
 public:
@@ -496,7 +521,8 @@ class LearnerConfiguration : public Learner {
     learner_model_param_.task = obj_->Task();  // required by gbm configuration.
     this->ConfigureGBM(old_tparam, args);
     ctx_.ConfigureGpuId(this->gbm_->UseGPU());
-    this->ConfigureModelParam();
+
+    this->ConfigureModelParamWithoutBaseScore();

     this->ConfigureMetrics(args);
@@ -510,8 +536,8 @@ class LearnerConfiguration : public Learner {
   }

   void CheckModelInitialized() const {
-    CHECK(learner_model_param_.Initialized()) << "Model not yet initialized.";
-    CHECK_NE(learner_model_param_.BaseScore(this->Ctx()).Size(), 0);
+    CHECK(learner_model_param_.Initialized()) << ModelNotFitted();
+    CHECK_NE(learner_model_param_.BaseScore(this->Ctx()).Size(), 0) << ModelNotFitted();
   }

   virtual PredictionContainer* GetPredictionCache() const {
@@ -1318,8 +1344,6 @@ class LearnerImpl : public LearnerIO {
                     HostDeviceVector<GradientPair>* in_gpair) override {
     monitor_.Start("BoostOneIter");
     this->Configure();
-    // Should have been set to default in the first prediction.
-    CHECK(mparam_.base_score_estimated);

     if (ctx_.seed_per_iteration) {
       common::GlobalRandom().seed(ctx_.seed * kRandSeedMagic + iter);
@@ -1380,7 +1404,9 @@ class LearnerImpl : public LearnerIO {
                                 static_cast<int>(pred_interactions) +
                                 static_cast<int>(pred_contribs);
     this->Configure();
-    this->InitBaseScore(nullptr);
+    if (training) {
+      this->InitBaseScore(nullptr);
+    }
     this->CheckModelInitialized();

     CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
@@ -1425,7 +1451,6 @@ class LearnerImpl : public LearnerIO {
                         HostDeviceVector<bst_float>** out_preds, uint32_t iteration_begin,
                         uint32_t iteration_end) override {
     this->Configure();
-    this->InitBaseScore(nullptr);
     this->CheckModelInitialized();

     auto& out_predictions = this->GetThreadLocal().prediction_entry;

View File

@@ -723,10 +723,15 @@ class MeanAbsoluteError : public ObjFunction {
       out(0) = common::Median(ctx_, info.labels, info.weights_) * w;
     }

+    // Weighted average base score across all workers
     collective::Allreduce<collective::Operation::kSum>(out.Values().data(), out.Values().size());
     collective::Allreduce<collective::Operation::kSum>(&w, 1);

+    if (common::CloseTo(w, 0.0)) {
+      // Mostly for handling empty dataset test.
+      LOG(WARNING) << "Sum of weights is close to 0.0, skipping base score estimation.";
+      out(0) = ObjFunction::DefaultBaseScore();
+      return;
+    }
     std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
                    [w](float v) { return v / w; });
   }
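The arithmetic behind the hunk above, as a standalone sketch: each worker contributes a weight-scaled local median; the allreduce (emulated here by plain sums) yields sum(median_i * w_i) and sum(w_i), and the guard falls back to the default base score when the total weight is near zero (e.g. an empty dataset). The default value 0.5 is a placeholder for ObjFunction::DefaultBaseScore():

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<double> medians{2.0, 3.0};     // per-worker weighted medians
  std::vector<double> weights{10.0, 30.0};   // per-worker weight sums

  double out = 0, w = 0;
  for (std::size_t i = 0; i < medians.size(); ++i) {
    out += medians[i] * weights[i];  // what Allreduce(kSum) accumulates
    w += weights[i];
  }
  double const kDefaultBaseScore = 0.5;  // placeholder default
  double base = std::abs(w) < 1e-12 ? kDefaultBaseScore : out / w;
  std::cout << base << '\n';  // (2*10 + 3*30) / 40 = 2.75
}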

View File

@@ -453,73 +453,162 @@ TEST(Learner, MultiTarget) {
 /**
  * Test the model initialization sequence is correctly performed.
  */
-TEST(Learner, InitEstimation) {
-  size_t constexpr kCols = 10;
-  auto Xy = RandomDataGenerator{10, kCols, 0}.GenerateDMatrix(true);
-
-  {
-    std::unique_ptr<Learner> learner{Learner::Create({Xy})};
-    learner->SetParam("objective", "reg:absoluteerror");
-    learner->Configure();
-    HostDeviceVector<float> predt;
-    learner->Predict(Xy, false, &predt, 0, 0);
-    auto h_predt = predt.ConstHostSpan();
-    for (auto v : h_predt) {
-      ASSERT_EQ(v, ObjFunction::DefaultBaseScore());
-    }
-    Json config{Object{}};
-    learner->SaveConfig(&config);
-    auto base_score =
-        std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
-    // No base score is estimated yet.
-    ASSERT_EQ(base_score, ObjFunction::DefaultBaseScore());
-  }
-  {
-    std::unique_ptr<Learner> learner{Learner::Create({Xy})};
-    learner->SetParam("objective", "reg:absoluteerror");
-    learner->UpdateOneIter(0, Xy);
-    HostDeviceVector<float> predt;
-    learner->Predict(Xy, false, &predt, 0, 0);
-    auto h_predt = predt.ConstHostSpan();
-    for (auto v : h_predt) {
-      ASSERT_NE(v, ObjFunction::DefaultBaseScore());
-    }
-    Json config{Object{}};
-    learner->SaveConfig(&config);
-    auto base_score =
-        std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
-    ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
-    ASSERT_THROW(
-        {
-          learner->SetParam("base_score_estimated", "1");
-          learner->Configure();
-        },
-        dmlc::Error);
-  }
-  {
-    std::unique_ptr<Learner> learner{Learner::Create({Xy})};
-    learner->SetParam("objective", "reg:absoluteerror");
-    learner->SetParam("base_score", "1.3");
-    learner->Configure();
-    HostDeviceVector<float> predt;
-    learner->Predict(Xy, false, &predt, 0, 0);
-    auto h_predt = predt.ConstHostSpan();
-    for (auto v : h_predt) {
-      ASSERT_FLOAT_EQ(v, 1.3);
-    }
-    learner->UpdateOneIter(0, Xy);
-    Json config{Object{}};
-    learner->SaveConfig(&config);
-    auto base_score =
-        std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
-    // no change
-    ASSERT_FLOAT_EQ(base_score, 1.3);
-  }
-}
+class InitBaseScore : public ::testing::Test {
+ protected:
+  std::size_t static constexpr Cols() { return 10; }
+  std::shared_ptr<DMatrix> Xy_;
+
+  void SetUp() override { Xy_ = RandomDataGenerator{10, Cols(), 0}.GenerateDMatrix(true); }
+
+  static float GetBaseScore(Json const &config) {
+    return std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
+  }
+
+ public:
+  void TestUpdateConfig() {
+    std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
+    learner->SetParam("objective", "reg:absoluteerror");
+    learner->UpdateOneIter(0, Xy_);
+    Json config{Object{}};
+    learner->SaveConfig(&config);
+    auto base_score = GetBaseScore(config);
+    ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
+
+    // already initialized
+    auto Xy1 = RandomDataGenerator{100, Cols(), 0}.Seed(321).GenerateDMatrix(true);
+    learner->UpdateOneIter(1, Xy1);
+    learner->SaveConfig(&config);
+    auto base_score1 = GetBaseScore(config);
+    ASSERT_EQ(base_score, base_score1);
+
+    Json model{Object{}};
+    learner->SaveModel(&model);
+    learner.reset(Learner::Create({}));
+    learner->LoadModel(model);
+    learner->Configure();
+    learner->UpdateOneIter(2, Xy1);
+    learner->SaveConfig(&config);
+    auto base_score2 = GetBaseScore(config);
+    ASSERT_EQ(base_score, base_score2);
+  }
+
+  void TestBoostFromAvgParam() {
+    std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
+    learner->SetParam("objective", "reg:absoluteerror");
+    learner->SetParam("base_score", "1.3");
+    Json config(Object{});
+    learner->Configure();
+    learner->SaveConfig(&config);
+
+    auto base_score = GetBaseScore(config);
+    // no change
+    ASSERT_FLOAT_EQ(base_score, 1.3);
+
+    HostDeviceVector<float> predt;
+    learner->Predict(Xy_, false, &predt, 0, 0);
+    auto h_predt = predt.ConstHostSpan();
+    for (auto v : h_predt) {
+      ASSERT_FLOAT_EQ(v, 1.3);
+    }
+    learner->UpdateOneIter(0, Xy_);
+    learner->SaveConfig(&config);
+    base_score = GetBaseScore(config);
+    // no change
+    ASSERT_FLOAT_EQ(base_score, 1.3);
+
+    auto from_avg = std::stoi(
+        get<String const>(config["learner"]["learner_model_param"]["boost_from_average"]));
+    // from_avg is disabled when base score is set
+    ASSERT_EQ(from_avg, 0);
+    // in the future when we can deprecate the binary model, user can set the parameter directly.
+    learner->SetParam("boost_from_average", "1");
+    learner->Configure();
+    learner->SaveConfig(&config);
+    from_avg = std::stoi(
+        get<String const>(config["learner"]["learner_model_param"]["boost_from_average"]));
+    ASSERT_EQ(from_avg, 1);
+  }
+
+  void TestInitAfterLoad() {
+    std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
+    learner->SetParam("objective", "reg:absoluteerror");
+    learner->Configure();
+
+    Json model{Object{}};
+    learner->SaveModel(&model);
+    auto base_score = GetBaseScore(model);
+    ASSERT_EQ(base_score, ObjFunction::DefaultBaseScore());
+
+    learner.reset(Learner::Create({Xy_}));
+    learner->LoadModel(model);
+    Json config(Object{});
+    learner->Configure();
+    learner->SaveConfig(&config);
+    base_score = GetBaseScore(config);
+    ASSERT_EQ(base_score, ObjFunction::DefaultBaseScore());
+
+    learner->UpdateOneIter(0, Xy_);
+    learner->SaveConfig(&config);
+    base_score = GetBaseScore(config);
+    ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
+  }
+
+  void TestInitWithPredt() {
+    std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
+    learner->SetParam("objective", "reg:absoluteerror");
+    HostDeviceVector<float> predt;
+    learner->Predict(Xy_, false, &predt, 0, 0);
+    auto h_predt = predt.ConstHostSpan();
+    for (auto v : h_predt) {
+      ASSERT_EQ(v, ObjFunction::DefaultBaseScore());
+    }
+
+    Json config(Object{});
+    learner->SaveConfig(&config);
+    auto base_score = GetBaseScore(config);
+    ASSERT_EQ(base_score, ObjFunction::DefaultBaseScore());
+
+    // since prediction is not used for training, the train procedure still runs estimation
+    learner->UpdateOneIter(0, Xy_);
+    learner->SaveConfig(&config);
+    base_score = GetBaseScore(config);
+    ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
+  }
+
+  void TestUpdateProcess() {
+    // Check that when training continuation is performed with update, the base score is
+    // not re-evaluated.
+    std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
+    learner->SetParam("objective", "reg:absoluteerror");
+    learner->Configure();
+    learner->UpdateOneIter(0, Xy_);
+    Json model{Object{}};
+    learner->SaveModel(&model);
+    auto base_score = GetBaseScore(model);
+
+    auto Xy1 = RandomDataGenerator{100, Cols(), 0}.Seed(321).GenerateDMatrix(true);
+    learner.reset(Learner::Create({Xy1}));
+    learner->LoadModel(model);
+    learner->SetParam("process_type", "update");
+    learner->SetParam("updater", "refresh");
+    learner->UpdateOneIter(1, Xy1);
+    Json config(Object{});
+    learner->SaveConfig(&config);
+    auto base_score1 = GetBaseScore(config);
+    ASSERT_EQ(base_score, base_score1);
+  }
+};
+
+TEST_F(InitBaseScore, TestUpdateConfig) { this->TestUpdateConfig(); }
+
+TEST_F(InitBaseScore, FromAvgParam) { this->TestBoostFromAvgParam(); }
+
+TEST_F(InitBaseScore, InitAfterLoad) { this->TestInitAfterLoad(); }
+
+TEST_F(InitBaseScore, InitWithPredict) { this->TestInitWithPredt(); }
+
+TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); }
 }  // namespace xgboost