Configuration for init estimation. (#8343)
* Configuration for init estimation. * Check whether the model needs configuration based on const attribute `ModelFitted` instead of a mutable state. * Add parameter `boost_from_average` to tell whether the user has specified base score. * Add tests.
This commit is contained in:
@@ -453,73 +453,162 @@ TEST(Learner, MultiTarget) {
|
||||
/**
|
||||
* Test the model initialization sequence is correctly performed.
|
||||
*/
|
||||
TEST(Learner, InitEstimation) {
|
||||
size_t constexpr kCols = 10;
|
||||
auto Xy = RandomDataGenerator{10, kCols, 0}.GenerateDMatrix(true);
|
||||
class InitBaseScore : public ::testing::Test {
|
||||
protected:
|
||||
std::size_t static constexpr Cols() { return 10; }
|
||||
std::shared_ptr<DMatrix> Xy_;
|
||||
|
||||
{
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
|
||||
void SetUp() override { Xy_ = RandomDataGenerator{10, Cols(), 0}.GenerateDMatrix(true); }
|
||||
|
||||
static float GetBaseScore(Json const &config) {
|
||||
return std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
|
||||
}
|
||||
|
||||
public:
|
||||
void TestUpdateConfig() {
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
||||
learner->SetParam("objective", "reg:absoluteerror");
|
||||
learner->UpdateOneIter(0, Xy_);
|
||||
Json config{Object{}};
|
||||
learner->SaveConfig(&config);
|
||||
auto base_score = GetBaseScore(config);
|
||||
ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
|
||||
|
||||
// already initialized
|
||||
auto Xy1 = RandomDataGenerator{100, Cols(), 0}.Seed(321).GenerateDMatrix(true);
|
||||
learner->UpdateOneIter(1, Xy1);
|
||||
learner->SaveConfig(&config);
|
||||
auto base_score1 = GetBaseScore(config);
|
||||
ASSERT_EQ(base_score, base_score1);
|
||||
|
||||
Json model{Object{}};
|
||||
learner->SaveModel(&model);
|
||||
learner.reset(Learner::Create({}));
|
||||
learner->LoadModel(model);
|
||||
learner->Configure();
|
||||
learner->UpdateOneIter(2, Xy1);
|
||||
learner->SaveConfig(&config);
|
||||
auto base_score2 = GetBaseScore(config);
|
||||
ASSERT_EQ(base_score, base_score2);
|
||||
}
|
||||
|
||||
void TestBoostFromAvgParam() {
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
||||
learner->SetParam("objective", "reg:absoluteerror");
|
||||
learner->SetParam("base_score", "1.3");
|
||||
Json config(Object{});
|
||||
learner->Configure();
|
||||
learner->SaveConfig(&config);
|
||||
|
||||
auto base_score = GetBaseScore(config);
|
||||
// no change
|
||||
ASSERT_FLOAT_EQ(base_score, 1.3);
|
||||
|
||||
HostDeviceVector<float> predt;
|
||||
learner->Predict(Xy_, false, &predt, 0, 0);
|
||||
auto h_predt = predt.ConstHostSpan();
|
||||
for (auto v : h_predt) {
|
||||
ASSERT_FLOAT_EQ(v, 1.3);
|
||||
}
|
||||
learner->UpdateOneIter(0, Xy_);
|
||||
learner->SaveConfig(&config);
|
||||
base_score = GetBaseScore(config);
|
||||
// no change
|
||||
ASSERT_FLOAT_EQ(base_score, 1.3);
|
||||
|
||||
auto from_avg = std::stoi(
|
||||
get<String const>(config["learner"]["learner_model_param"]["boost_from_average"]));
|
||||
// from_avg is disabled when base score is set
|
||||
ASSERT_EQ(from_avg, 0);
|
||||
// in the future when we can deprecate the binary model, user can set the parameter directly.
|
||||
learner->SetParam("boost_from_average", "1");
|
||||
learner->Configure();
|
||||
learner->SaveConfig(&config);
|
||||
from_avg = std::stoi(
|
||||
get<String const>(config["learner"]["learner_model_param"]["boost_from_average"]));
|
||||
ASSERT_EQ(from_avg, 1);
|
||||
}
|
||||
|
||||
void TestInitAfterLoad() {
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
||||
learner->SetParam("objective", "reg:absoluteerror");
|
||||
learner->Configure();
|
||||
|
||||
Json model{Object{}};
|
||||
learner->SaveModel(&model);
|
||||
auto base_score = GetBaseScore(model);
|
||||
ASSERT_EQ(base_score, ObjFunction::DefaultBaseScore());
|
||||
|
||||
learner.reset(Learner::Create({Xy_}));
|
||||
learner->LoadModel(model);
|
||||
Json config(Object{});
|
||||
learner->Configure();
|
||||
learner->SaveConfig(&config);
|
||||
base_score = GetBaseScore(config);
|
||||
ASSERT_EQ(base_score, ObjFunction::DefaultBaseScore());
|
||||
|
||||
learner->UpdateOneIter(0, Xy_);
|
||||
learner->SaveConfig(&config);
|
||||
base_score = GetBaseScore(config);
|
||||
ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
|
||||
}
|
||||
|
||||
void TestInitWithPredt() {
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
||||
learner->SetParam("objective", "reg:absoluteerror");
|
||||
HostDeviceVector<float> predt;
|
||||
learner->Predict(Xy, false, &predt, 0, 0);
|
||||
learner->Predict(Xy_, false, &predt, 0, 0);
|
||||
|
||||
auto h_predt = predt.ConstHostSpan();
|
||||
for (auto v : h_predt) {
|
||||
ASSERT_EQ(v, ObjFunction::DefaultBaseScore());
|
||||
}
|
||||
Json config{Object{}};
|
||||
|
||||
Json config(Object{});
|
||||
learner->SaveConfig(&config);
|
||||
auto base_score =
|
||||
std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
|
||||
// No base score is estimated yet.
|
||||
auto base_score = GetBaseScore(config);
|
||||
ASSERT_EQ(base_score, ObjFunction::DefaultBaseScore());
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
|
||||
learner->SetParam("objective", "reg:absoluteerror");
|
||||
learner->UpdateOneIter(0, Xy);
|
||||
|
||||
HostDeviceVector<float> predt;
|
||||
learner->Predict(Xy, false, &predt, 0, 0);
|
||||
auto h_predt = predt.ConstHostSpan();
|
||||
for (auto v : h_predt) {
|
||||
ASSERT_NE(v, ObjFunction::DefaultBaseScore());
|
||||
}
|
||||
|
||||
Json config{Object{}};
|
||||
// since prediction is not used for trianing, the train procedure still runs estimation
|
||||
learner->UpdateOneIter(0, Xy_);
|
||||
learner->SaveConfig(&config);
|
||||
auto base_score =
|
||||
std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
|
||||
base_score = GetBaseScore(config);
|
||||
ASSERT_NE(base_score, ObjFunction::DefaultBaseScore());
|
||||
|
||||
ASSERT_THROW(
|
||||
{
|
||||
learner->SetParam("base_score_estimated", "1");
|
||||
learner->Configure();
|
||||
},
|
||||
dmlc::Error);
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
|
||||
void TestUpdateProcess() {
|
||||
// Check that when training continuation is performed with update, the base score is
|
||||
// not re-evaluated.
|
||||
std::unique_ptr<Learner> learner{Learner::Create({Xy_})};
|
||||
learner->SetParam("objective", "reg:absoluteerror");
|
||||
learner->SetParam("base_score", "1.3");
|
||||
learner->Configure();
|
||||
HostDeviceVector<float> predt;
|
||||
learner->Predict(Xy, false, &predt, 0, 0);
|
||||
auto h_predt = predt.ConstHostSpan();
|
||||
for (auto v : h_predt) {
|
||||
ASSERT_FLOAT_EQ(v, 1.3);
|
||||
}
|
||||
learner->UpdateOneIter(0, Xy);
|
||||
Json config{Object{}};
|
||||
|
||||
learner->UpdateOneIter(0, Xy_);
|
||||
Json model{Object{}};
|
||||
learner->SaveModel(&model);
|
||||
auto base_score = GetBaseScore(model);
|
||||
|
||||
auto Xy1 = RandomDataGenerator{100, Cols(), 0}.Seed(321).GenerateDMatrix(true);
|
||||
learner.reset(Learner::Create({Xy1}));
|
||||
learner->LoadModel(model);
|
||||
learner->SetParam("process_type", "update");
|
||||
learner->SetParam("updater", "refresh");
|
||||
learner->UpdateOneIter(1, Xy1);
|
||||
|
||||
Json config(Object{});
|
||||
learner->SaveConfig(&config);
|
||||
auto base_score =
|
||||
std::stof(get<String const>(config["learner"]["learner_model_param"]["base_score"]));
|
||||
// no change
|
||||
ASSERT_FLOAT_EQ(base_score, 1.3);
|
||||
auto base_score1 = GetBaseScore(config);
|
||||
ASSERT_EQ(base_score, base_score1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(InitBaseScore, TestUpdateConfig) { this->TestUpdateConfig(); }
|
||||
|
||||
TEST_F(InitBaseScore, FromAvgParam) { this->TestBoostFromAvgParam(); }
|
||||
|
||||
TEST_F(InitBaseScore, InitAfterLoad) { this->TestInitAfterLoad(); }
|
||||
|
||||
TEST_F(InitBaseScore, InitWithPredict) { this->TestInitWithPredt(); }
|
||||
|
||||
TEST_F(InitBaseScore, UpdateProcess) { this->TestUpdateProcess(); }
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user