diff --git a/include/xgboost/json.h b/include/xgboost/json.h index 97dc2e11d..cdd86d135 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -202,6 +202,17 @@ class JsonInteger : public Value { typename std::enable_if::value>::type* = nullptr> JsonInteger(IntT value) : Value(ValueKind::Integer), // NOLINT integer_{static_cast(value)} {} + template ::value>::type* = nullptr> + JsonInteger(IntT value) : Value(ValueKind::Integer), // NOLINT + integer_{static_cast(value)} {} + template ::value && + !std::is_same::value>::type * = nullptr> + JsonInteger(IntT value) // NOLINT + : Value(ValueKind::Integer), + integer_{static_cast(value)} {} Json& operator[](std::string const & key) override; Json& operator[](int ind) override; @@ -533,8 +544,8 @@ using Null = JsonNull; // Utils tailored for XGBoost. -template -Object toJson(XGBoostParameter const& param) { +template +Object toJson(Parameter const& param) { Object obj; for (auto const& kv : param.__DICT__()) { obj[kv.first] = kv.second; @@ -542,8 +553,8 @@ Object toJson(XGBoostParameter const& param) { return obj; } -template -void fromJson(Json const& obj, XGBoostParameter* param) { +template +void fromJson(Json const& obj, Parameter* param) { auto const& j_param = get(obj); std::map m; for (auto const& kv : j_param) { diff --git a/src/common/json.cc b/src/common/json.cc index 2c8000787..915619924 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -243,7 +243,7 @@ Json& JsonNumber::operator[](int ind) { bool JsonNumber::operator==(Value const& rhs) const { if (!IsA(&rhs)) { return false; } - return number_ == Cast(&rhs)->getNumber(); + return std::abs(number_ - Cast(&rhs)->getNumber()) < kRtEps; } Value & JsonNumber::operator=(Value const &rhs) { @@ -504,7 +504,10 @@ Json JsonReader::ParseObject() { SkipSpaces(); char ch = PeekNextChar(); - if (ch == '}') return Json(std::move(data)); + if (ch == '}') { + GetChar('}'); + return Json(std::move(data)); + } while (true) { SkipSpaces(); diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index 94e87bcac..c4adbb42e 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -226,6 +226,10 @@ TEST(Json, EmptyObject) { std::stringstream iss(str); auto json = Json::Load(StringView{str.c_str(), str.size()}); ASSERT_TRUE(IsA(json["statistic"])); + + str = R"json({"Config": {},"Model": {}})json"; // NOLINT + json = Json::Load(StringView{str.c_str(), str.size()}); + ASSERT_TRUE(IsA(json["Model"])); } TEST(Json, EmptyArray) {