From 2a4df8e29f7b004cc2e991e7c0052b003bd3413f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 6 Aug 2019 03:10:49 -0400 Subject: [PATCH] Add Json integer, remove specialization. (#4739) --- include/xgboost/base.h | 6 + include/xgboost/json.h | 138 ++++++++++---------- include/xgboost/json_io.h | 74 ++--------- src/common/json.cc | 229 +++++++++++++++++++++++----------- tests/cpp/common/test_json.cc | 88 +++++++++---- 5 files changed, 314 insertions(+), 221 deletions(-) diff --git a/include/xgboost/base.h b/include/xgboost/base.h index f1940afba..0922cb22e 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -69,6 +69,12 @@ #define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z)) #endif // GLIBC VERSION +#if defined(__GNUC__) +#define XGBOOST_EXPECT(cond, ret) __builtin_expect((cond), (ret)) +#else +#define XGBOOST_EXPECT(cond, ret) (cond) +#endif // defined(__GNUC__) + /*! * \brief Tag function as usable by device */ diff --git a/include/xgboost/json.h b/include/xgboost/json.h index 46836f1ad..39deca1c2 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -4,8 +4,9 @@ #ifndef XGBOOST_JSON_H_ #define XGBOOST_JSON_H_ -#include +#include +#include #include #include @@ -29,7 +30,6 @@ class Value { Integer, Object, // std::map Array, // std::vector - Raw, Boolean, Null }; @@ -63,9 +63,9 @@ T* Cast(U* value) { if (IsA(value)) { return dynamic_cast(value); } else { - throw std::runtime_error( - "Invalid cast, from " + value->TypeStr() + " to " + T().TypeStr()); + LOG(FATAL) << "Invalid cast, from " + value->TypeStr() + " to " + T().TypeStr(); } + return dynamic_cast(value); // supress compiler warning. } class JsonString : public Value { @@ -123,32 +123,6 @@ class JsonArray : public Value { } }; -class JsonRaw : public Value { - std::string str_; - - public: - explicit JsonRaw(std::string&& str) : - Value(ValueKind::Raw), - str_{std::move(str)}{} // NOLINT - JsonRaw() : Value(ValueKind::Raw) {} - - std::string const& getRaw() && { return str_; } - std::string const& getRaw() const & { return str_; } - std::string& getRaw() & { return str_; } - - void Save(JsonWriter* writer) override; - - Json& operator[](std::string const & key) override; - Json& operator[](int ind) override; - - bool operator==(Value const& rhs) const override; - Value& operator=(Value const& rhs) override; - - static bool isClassOf(Value const* value) { - return value->Type() == ValueKind::Raw; - } -}; - class JsonObject : public Value { std::map object_; @@ -185,7 +159,9 @@ class JsonNumber : public Value { public: JsonNumber() : Value(ValueKind::Number) {} - JsonNumber(double value) : Value(ValueKind::Number) { // NOLINT + template ::value>::type* = nullptr> + JsonNumber(FloatT value) : Value(ValueKind::Number) { // NOLINT number_ = value; } @@ -198,6 +174,7 @@ class JsonNumber : public Value { Float const& getNumber() const & { return number_; } Float& getNumber() & { return number_; } + bool operator==(Value const& rhs) const override; Value& operator=(Value const& rhs) override; @@ -206,6 +183,35 @@ class JsonNumber : public Value { } }; +class JsonInteger : public Value { + public: + using Int = int64_t; + + private: + Int integer_; + + public: + JsonInteger() : Value(ValueKind::Integer), integer_{0} {} // NOLINT + template ::value>::type* = nullptr> + JsonInteger(IntT value) : Value(ValueKind::Integer), integer_{value} {} // NOLINT + + Json& operator[](std::string const & key) override; + Json& operator[](int ind) override; + + bool operator==(Value const& rhs) const override; + Value& operator=(Value const& rhs) override; + + Int const& getInteger() && { return integer_; } + Int const& getInteger() const & { return integer_; } + Int& getInteger() & { return integer_; } + void Save(JsonWriter* writer) override; + + static bool isClassOf(Value const* value) { + return value->Type() == ValueKind::Integer; + } +}; + class JsonNull : public Value { public: JsonNull() : Value(ValueKind::Null) {} @@ -256,15 +262,16 @@ class JsonBoolean : public Value { }; struct StringView { - char const* str_; + using CharT = char; // unsigned char + CharT const* str_; size_t size_; public: StringView() = default; - StringView(char const* str, size_t size) : str_{str}, size_{size} {} + StringView(CharT const* str, size_t size) : str_{str}, size_{size} {} - char const& operator[](size_t p) const { return str_[p]; } - char const& at(size_t p) const { // NOLINT + CharT const& operator[](size_t p) const { return str_[p]; } + CharT const& at(size_t p) const { // NOLINT CHECK_LT(p, size_); return str_[p]; } @@ -302,7 +309,7 @@ class Json { public: /*! \brief Load a Json object from string. */ - static Json Load(StringView str, bool ignore_specialization = false); + static Json Load(StringView str); /*! \brief Pass your own JsonReader. */ static Json Load(JsonReader* reader); /*! \brief Dump json into stream. */ @@ -319,6 +326,13 @@ class Json { return *this; } + // integer + explicit Json(JsonInteger integer) : ptr_{new JsonInteger(integer)} {} + Json& operator=(JsonInteger integer) { + ptr_.reset(new JsonInteger(std::move(integer))); + return *this; + } + // array explicit Json(JsonArray list) : ptr_ {new JsonArray(std::move(list))} {} @@ -327,14 +341,6 @@ class Json { return *this; } - // raw - explicit Json(JsonRaw str) : - ptr_{new JsonRaw(std::move(str))} {} - Json& operator=(JsonRaw str) { - ptr_.reset(new JsonRaw(std::move(str))); - return *this; - } - // object explicit Json(JsonObject object) : ptr_{new JsonObject(std::move(object))} {} @@ -410,10 +416,24 @@ JsonNumber::Float& GetImpl(T& val) { // NOLINT template ::value>::type* = nullptr> -double const& GetImpl(T& val) { // NOLINT +JsonNumber::Float const& GetImpl(T& val) { // NOLINT return val.getNumber(); } +// Integer +template ::value>::type* = nullptr> +JsonInteger::Int& GetImpl(T& val) { // NOLINT + return val.getInteger(); +} +template ::value>::type* = nullptr> +JsonInteger::Int const& GetImpl(T& val) { // NOLINT + return val.getInteger(); +} + // String template ::value>::type* = nullptr> -std::string& GetImpl(T& val) { // NOLINT - return val.getRaw(); -} -template ::value>::type* = nullptr> -std::string const& GetImpl(T& val) { // NOLINT - return val.getRaw(); -} - // Array template decltype(detail::GetImpl(*Cast(&json.GetValue())))& { // using Object = JsonObject; using Array = JsonArray; using Number = JsonNumber; +using Integer = JsonInteger; using Boolean = JsonBoolean; using String = JsonString; using Null = JsonNull; -using Raw = JsonRaw; // Utils tailored for XGBoost. @@ -518,13 +525,14 @@ Object toJson(dmlc::Parameter const& param) { return obj; } -inline std::map fromJson(std::map const& param) { - std::map res; - for (auto const& kv : param) { - res[kv.first] = get(kv.second); +template +void fromJson(Json const& obj, dmlc::Parameter* param) { + auto const& j_param = get(obj); + std::map m; + for (auto const& kv : j_param) { + m[kv.first] = get(kv.second); } - return res; + param->InitAllowUnknown(m); } - } // namespace xgboost #endif // XGBOOST_JSON_H_ diff --git a/include/xgboost/json_io.h b/include/xgboost/json_io.h index d0323354d..deaacf8c6 100644 --- a/include/xgboost/json_io.h +++ b/include/xgboost/json_io.h @@ -22,50 +22,15 @@ class FixedPrecisionStreamContainer : public std::basic_stringstream< public: FixedPrecisionStreamContainer() { this->precision(std::numeric_limits::max_digits10); + this->imbue(std::locale("C")); + this->setf(std::ios::scientific); } }; using FixedPrecisionStream = FixedPrecisionStreamContainer>; /* - * \brief An reader that can be specialised. - * - * Why specialization? - * - * First of all, we don't like specialization. This is purely for performance concern. - * Distributed environment freqently serializes model so at some point this could be a - * bottle neck for training performance. There are many other techniques for obtaining - * better performance, but all of them requires implementing thier own allocaltor(s), - * using simd instructions. And few of them can provide a easy to modify structure - * since they assumes a fixed memory layout. - * - * In XGBoost we provide specialized logic for parsing/writing tree models and linear - * models, where dense numeric values is presented, including weights, node ids etc. - * - * Plan for removing the specialization: - * - * We plan to upstream this implementaion into DMLC as it matures. For XGBoost, most of - * the time spent in load/dump is actually `sprintf`. - * - * To enable specialization, register a keyword that corresponds to - * key in Json object. For example in: - * - * \code - * { "key": {...} } - * \endcode - * - * To add special logic for parsing {...}, one can call: - * - * \code - * JsonReader::registry("key", [](StringView str, size_t* pos){ ... return JsonRaw(...); }); - * \endcode - * - * Where str is a view of entire input string, while pos is a pointer to current position. - * The function must return a raw object. Later after obtaining a parsed object, say - * `Json obj`, you can obtain * the raw object by calling `obj["key"]' then perform the - * specialized parsing on it. - * - * See `LinearSelectRaw` and `LinearReader` in combination as an example. + * \brief A json reader, currently error checking and utf-8 is not fully supported. */ class JsonReader { protected: @@ -77,17 +42,19 @@ class JsonReader { public: SourceLocation() : pos_(0) {} - explicit SourceLocation(size_t pos) : pos_{pos} {} size_t Pos() const { return pos_; } - SourceLocation& Forward(char c = 0) { + SourceLocation& Forward() { pos_++; return *this; } + SourceLocation& Forward(uint32_t n) { + pos_ += n; + return *this; + } } cursor_; StringView raw_str_; - bool ignore_specialization_; protected: void SkipSpaces(); @@ -140,32 +107,13 @@ class JsonReader { Json Parse(); - private: - using Fn = std::function; - public: - explicit JsonReader(StringView str, bool ignore = false) : - raw_str_{str}, - ignore_specialization_{ignore} {} - explicit JsonReader(StringView str, size_t pos, bool ignore = false) : - cursor_{pos}, - raw_str_{str}, - ignore_specialization_{ignore} {} + explicit JsonReader(StringView str) : + raw_str_{str} {} virtual ~JsonReader() = default; Json Load(); - - static std::map& getRegistry() { - static std::map set; - return set; - } - - static std::map const& registry( - std::string const& key, Fn fn) { - getRegistry()[key] = fn; - return getRegistry(); - } }; class JsonWriter { @@ -207,7 +155,7 @@ class JsonWriter { virtual void Visit(JsonArray const* arr); virtual void Visit(JsonObject const* obj); virtual void Visit(JsonNumber const* num); - virtual void Visit(JsonRaw const* raw); + virtual void Visit(JsonInteger const* num); virtual void Visit(JsonNull const* null); virtual void Visit(JsonString const* str); virtual void Visit(JsonBoolean const* boolean); diff --git a/src/common/json.cc b/src/common/json.cc index 917aa1748..396140a06 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -2,11 +2,13 @@ * Copyright (c) by Contributors 2019 */ #include +#include +#include +#include "xgboost/base.h" #include "xgboost/logging.h" #include "xgboost/json.h" #include "xgboost/json_io.h" -#include "../common/timer.h" namespace xgboost { @@ -56,9 +58,11 @@ void JsonWriter::Visit(JsonNumber const* num) { convertor_.str(""); } -void JsonWriter::Visit(JsonRaw const* raw) { - auto const& str = raw->getRaw(); - this->Write(str); +void JsonWriter::Visit(JsonInteger const* num) { + convertor_ << num->getInteger(); + auto const& str = convertor_.str(); + this->Write(StringView{str.c_str(), str.size()}); + convertor_.str(""); } void JsonWriter::Visit(JsonNull const* null) { @@ -120,7 +124,6 @@ std::string Value::TypeStr() const { case ValueKind::Array: return "Array"; break; case ValueKind::Boolean: return "Boolean"; break; case ValueKind::Null: return "Null"; break; - case ValueKind::Raw: return "Raw"; break; case ValueKind::Integer: return "Integer"; break; } return ""; @@ -225,35 +228,6 @@ void JsonArray::Save(JsonWriter* writer) { writer->Visit(this); } -// Json raw -Json& JsonRaw::operator[](std::string const & key) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by string."; - return DummyJsonObject(); -} - -Json& JsonRaw::operator[](int ind) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by Integer."; - return DummyJsonObject(); -} - -bool JsonRaw::operator==(Value const& rhs) const { - if (!IsA(&rhs)) { return false; } - auto& arr = Cast(&rhs)->getRaw(); - return std::equal(arr.cbegin(), arr.cend(), str_.cbegin()); -} - -Value & JsonRaw::operator=(Value const &rhs) { - auto const* casted = Cast(&rhs); - str_ = casted->getRaw(); - return *this; -} - -void JsonRaw::Save(JsonWriter* writer) { - writer->Visit(this); -} - // Json Number Json& JsonNumber::operator[](std::string const & key) { LOG(FATAL) << "Object of type " @@ -282,6 +256,34 @@ void JsonNumber::Save(JsonWriter* writer) { writer->Visit(this); } +// Json Integer +Json& JsonInteger::operator[](std::string const& key) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by string."; + return DummyJsonObject(); +} + +Json& JsonInteger::operator[](int ind) { + LOG(FATAL) << "Object of type " + << Value::TypeStr() << " can not be indexed by Integer."; + return DummyJsonObject(); +} + +bool JsonInteger::operator==(Value const& rhs) const { + if (!IsA(&rhs)) { return false; } + return integer_ == Cast(&rhs)->getInteger(); +} + +Value & JsonInteger::operator=(Value const &rhs) { + JsonInteger const* casted = Cast(&rhs); + integer_ = casted->getInteger(); + return *this; +} + +void JsonInteger::Save(JsonWriter* writer) { + writer->Visit(this); +} + // Json Null Json& JsonNull::operator[](std::string const & key) { LOG(FATAL) << "Object of type " @@ -377,7 +379,8 @@ void JsonReader::Error(std::string msg) const { msg += '\n'; constexpr size_t kExtend = 8; - auto beg = cursor_.Pos() - kExtend < 0 ? 0 : cursor_.Pos() - kExtend; + auto beg = static_cast(cursor_.Pos()) - + static_cast(kExtend) < 0 ? 0 : cursor_.Pos() - kExtend; auto end = cursor_.Pos() + kExtend >= raw_str_.size() ? raw_str_.size() : cursor_.Pos() + kExtend; @@ -401,7 +404,7 @@ void JsonReader::SkipSpaces() { while (cursor_.Pos() < raw_str_.size()) { char c = raw_str_[cursor_.Pos()]; if (std::isspace(c)) { - cursor_.Forward(c); + cursor_.Forward(); } else { break; } @@ -493,6 +496,8 @@ Json JsonReader::ParseObject() { while (true) { SkipSpaces(); ch = PeekNextChar(); + CHECK_NE(ch, -1) << "cursor_.Pos(): " << cursor_.Pos() << ", " + << "raw_str_.size():" << raw_str_.size(); if (ch != '"') { Expect('"', ch); } @@ -504,16 +509,9 @@ Json JsonReader::ParseObject() { Expect(':', ch); } - Json value; - if (!ignore_specialization_ && - (getRegistry().find(get(key)) != getRegistry().cend())) { - LOG(DEBUG) << "Using specialized parser for: " << get(key); - value = getRegistry().at(get(key))(raw_str_, &(cursor_.pos_)); - } else { - value = Parse(); - } + Json value { Parse() }; - data[get(key)] = std::move(value); + data[get(key)] = std::move(value); ch = GetNextNonSpaceChar(); @@ -527,15 +525,118 @@ Json JsonReader::ParseObject() { } Json JsonReader::ParseNumber() { - std::string substr = raw_str_.substr(cursor_.Pos(), kMaxNumLength); - size_t pos = 0; + // Adopted from sajson with some simplifications and small optimizations. + char const* p = raw_str_.c_str() + cursor_.Pos(); + char const* const beg = p; // keep track of current pointer - Number::Float number{0}; - number = std::stof(substr, &pos); - for (size_t i = 0; i < pos; ++i) { - GetNextChar(); + // TODO(trivialfis): Add back all the checks for number + bool negative = false; + if ('-' == *p) { + ++p; + negative = true; + } + + bool is_float = false; + + using ExpInt = std::remove_const< + decltype(std::numeric_limits::max_exponent)>::type; + constexpr auto kExpMax = std::numeric_limits::max(); + constexpr auto kExpMin = std::numeric_limits::min(); + + JsonInteger::Int i = 0; + double f = 0.0; // Use double to maintain accuracy + + if (*p == '0') { + ++p; + } else { + char c = *p; + do { + ++p; + char digit = c - '0'; + i = 10 * i + digit; + c = *p; + } while (std::isdigit(c)); + } + + ExpInt exponent = 0; + const char *const dot_position = p; + if ('.' == *p) { + is_float = true; + f = i; + ++p; + char c = *p; + + do { + ++p; + f = f * 10 + (c - '0'); + c = *p; + } while (std::isdigit(c)); + } + if (is_float) { + exponent = dot_position - p + 1; + } + + char e = *p; + if ('e' == e || 'E' == e) { + if (!is_float) { + is_float = true; + f = i; + } + ++p; + + bool negative_exponent = false; + if ('-' == *p) { + negative_exponent = true; + ++p; + } else if ('+' == *p) { + ++p; + } + + ExpInt exp = 0; + + char c = *p; + while (std::isdigit(c)) { + unsigned char digit = c - '0'; + if (XGBOOST_EXPECT(exp > (kExpMax - digit) / 10, false)) { + CHECK_GT(exp, (kExpMax - digit) / 10) << "Overflow"; + } + exp = 10 * exp + digit; + ++p; + c = *p; + } + static_assert(-kExpMax >= kExpMin, "exp can be negated without loss or UB"); + exponent += (negative_exponent ? -exp : exp); + } + + if (exponent) { + CHECK(is_float); + // If d is zero but the exponent is huge, don't + // multiply zero by inf which gives nan. + if (f != 0.0) { + // Only use exp10 from libc on gcc+linux +#if !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) +#define exp10(val) std::pow(10, (val)) +#endif // !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) + f *= exp10(exponent); +#if !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) +#undef exp10 +#endif // !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) + } + } + + if (negative) { + f = -f; + i = -i; + } + + auto moved = std::distance(beg, p); + this->cursor_.Forward(moved); + + if (is_float) { + return Json(static_cast(f)); + } else { + return Json(JsonInteger(i)); } - return Json(number); } Json JsonReader::ParseBoolean() { @@ -566,7 +667,7 @@ Json JsonReader::ParseBoolean() { } // This is an ad-hoc solution for writing numeric value in standard way. We need to add -// something locale independent way of writing stream. +// a locale independent way of writing stream like `std::{from, to}_chars' from C++-17. // FIXME(trivialfis): Remove this. class GlobalCLocale { std::locale ori_; @@ -585,39 +686,23 @@ class GlobalCLocale { } }; -Json Json::Load(StringView str, bool ignore_specialization) { +Json Json::Load(StringView str) { GlobalCLocale guard; - LOG(WARNING) << "Json serialization is still experimental." - " Output schema is subject to change in the future."; - JsonReader reader(str, ignore_specialization); - common::Timer t; - t.Start(); + JsonReader reader(str); Json json{reader.Load()}; - t.Stop(); - t.PrintElapsed("Json::load"); return json; } Json Json::Load(JsonReader* reader) { GlobalCLocale guard; - common::Timer t; - t.Start(); Json json{reader->Load()}; - t.Stop(); - t.PrintElapsed("Json::load"); return json; } void Json::Dump(Json json, std::ostream *stream, bool pretty) { GlobalCLocale guard; - LOG(WARNING) << "Json serialization is still experimental." - " Output schema is subject to change in the future."; - JsonWriter writer(stream, true); - common::Timer t; - t.Start(); + JsonWriter writer(stream, pretty); writer.Save(json); - t.Stop(); - t.PrintElapsed("Json::dump"); } Json& Json::operator=(Json const &other) = default; diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index a599c3f19..1b37ba885 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -143,9 +143,26 @@ TEST(Json, TestParseObject) { } TEST(Json, ParseNumber) { - std::string str = "31.8892"; - auto json = Json::Load(StringView{str.c_str(), str.size()}); - ASSERT_NEAR(get(json), 31.8892f, kRtEps); + { + std::string str = "31.8892"; + auto json = Json::Load(StringView{str.c_str(), str.size()}); + ASSERT_NEAR(get(json), 31.8892f, kRtEps); + } + { + std::string str = "-31.8892"; + auto json = Json::Load(StringView{str.c_str(), str.size()}); + ASSERT_NEAR(get(json), -31.8892f, kRtEps); + } + { + std::string str = "2e4"; + auto json = Json::Load(StringView{str.c_str(), str.size()}); + ASSERT_NEAR(get(json), 2e4f, kRtEps); + } + { + std::string str = "2e-4"; + auto json = Json::Load(StringView{str.c_str(), str.size()}); + ASSERT_NEAR(get(json), 2e-4f, kRtEps); + } } TEST(Json, ParseArray) { @@ -176,12 +193,13 @@ TEST(Json, ParseArray) { ] } )json"; - auto json = Json::Load(StringView{str.c_str(), str.size()}, true); + auto json = Json::Load(StringView{str.c_str(), str.size()}); json = json["nodes"]; std::vector arr = get(json); ASSERT_EQ(arr.size(), 3); Json v0 = arr[0]; - ASSERT_EQ(get(v0["depth"]), 3); + ASSERT_EQ(get(v0["depth"]), 3); + ASSERT_NEAR(get(v0["gain"]), 10.4866, kRtEps); } TEST(Json, Null) { @@ -203,7 +221,7 @@ TEST(Json, EmptyArray) { } )json"; std::istringstream iss(str); - auto json = Json::Load(StringView{str.c_str(), str.size()}, true); + auto json = Json::Load(StringView{str.c_str(), str.size()}); auto arr = get(json["leaf_vector"]); ASSERT_EQ(arr.size(), 0); } @@ -215,14 +233,14 @@ TEST(Json, Boolean) { "right_child": false } )json"; - Json j {Json::Load(StringView{str.c_str(), str.size()}, true)}; + Json j {Json::Load(StringView{str.c_str(), str.size()})}; ASSERT_EQ(get(j["left_child"]), true); ASSERT_EQ(get(j["right_child"]), false); } TEST(Json, Indexing) { auto str = GetModelStr(); - JsonReader reader(StringView{str.c_str(), str.size()}, true); + JsonReader reader(StringView{str.c_str(), str.size()}); Json j {Json::Load(&reader)}; auto& value_1 = j["model_parameter"]; auto& value = value_1["base_score"]; @@ -242,7 +260,7 @@ TEST(Json, AssigningObjects) { { std::map objects; Json json_objects { JsonObject() }; - std::vector arr_0 (1, Json(3.3)); + std::vector arr_0 (1, Json(3.3f)); json_objects["tree_parameters"] = JsonArray(arr_0); std::vector json_arr = get(json_objects["tree_parameters"]); ASSERT_NEAR(get(json_arr[0]), 3.3f, kRtEps); @@ -263,9 +281,9 @@ TEST(Json, AssigningObjects) { TEST(Json, AssigningArray) { Json json; json = JsonArray(); - std::vector tmp_0 {Json(Number(1)), Json(Number(2))}; + std::vector tmp_0 {Json(Number(1.0f)), Json(Number(2.0f))}; json = tmp_0; - std::vector tmp_1 {Json(Number(3))}; + std::vector tmp_1 {Json(Number(3.0f))}; get(json) = tmp_1; std::vector res = get(json); ASSERT_EQ(get(res[0]), 3); @@ -274,14 +292,14 @@ TEST(Json, AssigningArray) { TEST(Json, AssigningNumber) { { // right value - Json json = Json{ Number(4) }; + Json json = Json{ Number(4.0f) }; get(json) = 15; ASSERT_EQ(get(json), 15); } { // left value ref - Json json = Json{ Number(4) }; + Json json = Json{ Number(4.0f) }; Number::Float& ref = get(json); ref = 15; ASSERT_EQ(get(json), 15); @@ -289,7 +307,7 @@ TEST(Json, AssigningNumber) { { // left value - Json json = Json{ Number(4) }; + Json json = Json{ Number(4.0f) }; double value = get(json); ASSERT_EQ(value, 4); value = 15; // NOLINT @@ -323,8 +341,8 @@ TEST(Json, AssigningString) { } TEST(Json, LoadDump) { - std::string buffer = GetModelStr(); - Json origin {Json::Load(StringView{buffer.c_str(), buffer.size()}, true)}; + std::string ori_buffer = GetModelStr(); + Json origin {Json::Load(StringView{ori_buffer.c_str(), ori_buffer.size()})}; dmlc::TemporaryDirectory tempdir; auto const& path = tempdir.path + "test_model_dump"; @@ -333,10 +351,11 @@ TEST(Json, LoadDump) { Json::Dump(origin, &fout); fout.close(); - buffer = common::LoadSequentialFile(path); - Json load_back {Json::Load(StringView(buffer.c_str(), buffer.size()), true)}; + std::string new_buffer = common::LoadSequentialFile(path); + Json load_back {Json::Load(StringView(new_buffer.c_str(), new_buffer.size()))}; - ASSERT_EQ(load_back, origin); + ASSERT_EQ(load_back, origin) << ori_buffer << "\n\n---------------\n\n" + << new_buffer; } // For now Json is quite ignorance about unicode. @@ -344,7 +363,7 @@ TEST(Json, CopyUnicode) { std::string json_str = R"json( {"m": ["\ud834\udd1e", "\u20ac", "\u0416", "\u00f6"]} )json"; - Json loaded {Json::Load(StringView{json_str.c_str(), json_str.size()}, true)}; + Json loaded {Json::Load(StringView{json_str.c_str(), json_str.size()})}; std::stringstream ss_1; Json::Dump(loaded, &ss_1); @@ -359,7 +378,7 @@ TEST(Json, WrongCasts) { ASSERT_ANY_THROW(get(json)); } { - Json json = Json{ Array{ std::vector{ Json{ Number{1} } } } }; + Json json = Json{ Array{ std::vector{ Json{ Number{1.0f} } } } }; ASSERT_ANY_THROW(get(json)); } { @@ -368,4 +387,31 @@ TEST(Json, WrongCasts) { ASSERT_ANY_THROW(get(json)); } } + +TEST(Json, Int_vs_Float) { + // If integer is parsed as float, calling `get()' will throw. + { + std::string str = R"json( +{ + "number": 123.4, + "integer": 123 +})json"; + + Json obj = Json::Load({str.c_str(), str.size()}); + JsonNumber::Float number = get(obj["number"]); + ASSERT_NEAR(number, 123.4f, kRtEps); + JsonInteger::Int integer = get(obj["integer"]); + ASSERT_EQ(integer, 123); + } + + { + std::string str = R"json( +{"data": [2503595760, false], "shape": [10]} +)json"; + Json obj = Json::Load({str.c_str(), str.size()}); + auto array = get(obj["data"]); + auto ptr = get(array[0]); + ASSERT_EQ(ptr, 2503595760); + } +} } // namespace xgboost