Unify evaluation functions. (#6037)
This commit is contained in:
@@ -11,6 +11,51 @@
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
void CompareJSON(Json l, Json r) {
|
||||
switch (l.GetValue().Type()) {
|
||||
case Value::ValueKind::kString: {
|
||||
ASSERT_EQ(l, r);
|
||||
break;
|
||||
}
|
||||
case Value::ValueKind::kNumber: {
|
||||
ASSERT_NEAR(get<Number>(l), get<Number>(r), kRtEps);
|
||||
break;
|
||||
}
|
||||
case Value::ValueKind::kInteger: {
|
||||
ASSERT_EQ(l, r);
|
||||
break;
|
||||
}
|
||||
case Value::ValueKind::kObject: {
|
||||
auto const &l_obj = get<Object const>(l);
|
||||
auto const &r_obj = get<Object const>(r);
|
||||
ASSERT_EQ(l_obj.size(), r_obj.size());
|
||||
|
||||
for (auto const& kv : l_obj) {
|
||||
ASSERT_NE(r_obj.find(kv.first), r_obj.cend());
|
||||
CompareJSON(l_obj.at(kv.first), r_obj.at(kv.first));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Value::ValueKind::kArray: {
|
||||
auto const& l_arr = get<Array const>(l);
|
||||
auto const& r_arr = get<Array const>(r);
|
||||
ASSERT_EQ(l_arr.size(), r_arr.size());
|
||||
for (size_t i = 0; i < l_arr.size(); ++i) {
|
||||
CompareJSON(l_arr[i], r_arr[i]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Value::ValueKind::kBoolean: {
|
||||
ASSERT_EQ(l, r);
|
||||
break;
|
||||
}
|
||||
case Value::ValueKind::kNull: {
|
||||
ASSERT_EQ(l, r);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr<DMatrix> p_dmat) {
|
||||
for (auto& batch : p_dmat->GetBatches<SparsePage>()) {
|
||||
batch.data.HostVector();
|
||||
@@ -104,7 +149,7 @@ void TestLearnerSerialization(Args args, FeatureMap const& fmap, std::shared_ptr
|
||||
|
||||
Json m_0 = Json::Load(StringView{continued_model.c_str(), continued_model.size()});
|
||||
Json m_1 = Json::Load(StringView{model_at_2kiter.c_str(), model_at_2kiter.size()});
|
||||
ASSERT_EQ(m_0, m_1);
|
||||
CompareJSON(m_0, m_1);
|
||||
}
|
||||
|
||||
// Test training continuation with data from device.
|
||||
@@ -323,7 +368,7 @@ TEST_F(SerializationTest, ConfigurationCount) {
|
||||
occureences ++;
|
||||
pos += target.size();
|
||||
}
|
||||
ASSERT_EQ(occureences, 2);
|
||||
ASSERT_EQ(occureences, 2ul);
|
||||
|
||||
xgboost::ConsoleLogger::Configure({{"verbosity", "2"}});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user