From c635d4c46aa33f19f42c47d08c9630dd286fc43f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 10 Jan 2022 23:24:23 +0800 Subject: [PATCH] Implement ubjson. (#7549) * Implement ubjson. This is a partial implementation of UBJSON with support for typed arrays. Some missing features are `f64`, typed object, and the no-op. --- include/xgboost/intrusive_ptr.h | 4 +- include/xgboost/json.h | 167 +++++++--- include/xgboost/json_io.h | 153 ++++++++- src/common/json.cc | 569 ++++++++++++++++++++++---------- tests/cpp/common/test_json.cc | 108 +++++- tests/cpp/test_serialization.cc | 33 +- 6 files changed, 792 insertions(+), 242 deletions(-) diff --git a/include/xgboost/intrusive_ptr.h b/include/xgboost/intrusive_ptr.h index df7ae3021..a0b860be5 100644 --- a/include/xgboost/intrusive_ptr.h +++ b/include/xgboost/intrusive_ptr.h @@ -127,8 +127,8 @@ template class IntrusivePtr { ptr_ = nullptr; } void reset(element_type *that) { IntrusivePtr{that}.swap(*this); } // NOLINT - - element_type &operator*() const noexcept { return *ptr_; } + // clang-tidy might manufacture a null value, disable the check + element_type &operator*() const noexcept { return *ptr_; } // NOLINT element_type *operator->() const noexcept { return ptr_; } element_type *get() const noexcept { return ptr_; } // NOLINT diff --git a/include/xgboost/json.h b/include/xgboost/json.h index 62f5ad8f6..885a0d1cd 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) by XGBoost Contributors 2019-2021 + * Copyright (c) by XGBoost Contributors 2019-2022 */ #ifndef XGBOOST_JSON_H_ #define XGBOOST_JSON_H_ @@ -39,7 +39,12 @@ class Value { kObject, // std::map kArray, // std::vector kBoolean, - kNull + kNull, + // typed array for ubjson + kNumberArray, + kU8Array, + kI32Array, + kI64Array }; explicit Value(ValueKind _kind) : kind_{_kind} {} @@ -47,13 +52,13 @@ class Value { ValueKind Type() const { return kind_; } virtual ~Value() = default; - virtual void Save(JsonWriter* writer) = 0; + virtual void Save(JsonWriter* writer) const = 0; - virtual Json& operator[](std::string const & key) = 0; - virtual Json& operator[](int ind) = 0; + virtual Json& operator[](std::string const& key); + virtual Json& operator[](int ind); virtual bool operator==(Value const& rhs) const = 0; - virtual Value& operator=(Value const& rhs) = 0; + virtual Value& operator=(Value const& rhs) = delete; std::string TypeStr() const; @@ -88,17 +93,13 @@ class JsonString : public Value { JsonString(JsonString&& str) noexcept : // NOLINT Value(ValueKind::kString), str_{std::move(str.str_)} {} - void Save(JsonWriter* writer) override; - - Json& operator[](std::string const & key) override; - Json& operator[](int ind) override; + void Save(JsonWriter* writer) const override; std::string const& GetString() && { return str_; } std::string const& GetString() const & { return str_; } std::string& GetString() & { return str_; } bool operator==(Value const& rhs) const override; - Value& operator=(Value const& rhs) override; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kString; @@ -117,23 +118,71 @@ class JsonArray : public Value { JsonArray(JsonArray const& that) = delete; JsonArray(JsonArray && that) noexcept; - void Save(JsonWriter* writer) override; + void Save(JsonWriter* writer) const override; - Json& operator[](std::string const & key) override; - Json& operator[](int ind) override; + Json& operator[](int ind) override { return vec_.at(ind); } + // silent the partial oveeridden warning + Json& operator[](std::string const& key) override { return Value::operator[](key); } std::vector const& GetArray() && { return vec_; } std::vector const& GetArray() const & { return vec_; } std::vector& GetArray() & { return vec_; } bool operator==(Value const& rhs) const override; - Value& operator=(Value const& rhs) override; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kArray; } }; +/** + * \brief Typed array for Universal Binary JSON. + * + * \tparam T The underlying primitive type. + * \tparam kind Value kind defined by JSON type. + */ +template +class JsonTypedArray : public Value { + std::vector vec_; + + public: + using Type = T; + + JsonTypedArray() : Value(kind) {} + explicit JsonTypedArray(size_t n) : Value(kind) { vec_.resize(n); } + JsonTypedArray(JsonTypedArray&& that) noexcept : Value{kind}, vec_{std::move(that.vec_)} {} + + bool operator==(Value const& rhs) const override; + + void Set(size_t i, T v) { vec_[i] = v; } + size_t Size() const { return vec_.size(); } + + void Save(JsonWriter* writer) const override; + + std::vector const& GetArray() && { return vec_; } + std::vector const& GetArray() const& { return vec_; } + std::vector& GetArray() & { return vec_; } + + static bool IsClassOf(Value const* value) { return value->Type() == kind; } +}; + +/** + * \brief Typed UBJSON array for 32-bit floating point. + */ +using F32Array = JsonTypedArray; +/** + * \brief Typed UBJSON array for uint8_t. + */ +using U8Array = JsonTypedArray; +/** + * \brief Typed UBJSON array for int32_t. + */ +using I32Array = JsonTypedArray; +/** + * \brief Typed UBJSON array for int64_t. + */ +using I64Array = JsonTypedArray; + class JsonObject : public Value { std::map object_; @@ -143,17 +192,17 @@ class JsonObject : public Value { JsonObject(JsonObject const& that) = delete; JsonObject(JsonObject && that) noexcept; - void Save(JsonWriter* writer) override; + void Save(JsonWriter* writer) const override; - Json& operator[](std::string const & key) override; - Json& operator[](int ind) override; + // silent the partial oveeridden warning + Json& operator[](int ind) override { return Value::operator[](ind); } + Json& operator[](std::string const& key) override { return object_[key]; } std::map const& GetObject() && { return object_; } std::map const& GetObject() const & { return object_; } std::map & GetObject() & { return object_; } bool operator==(Value const& rhs) const override; - Value& operator=(Value const& rhs) override; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kObject; @@ -182,18 +231,13 @@ class JsonNumber : public Value { JsonNumber(JsonNumber const& that) = delete; JsonNumber(JsonNumber&& that) noexcept : Value{ValueKind::kNumber}, number_{that.number_} {} - void Save(JsonWriter* writer) override; - - Json& operator[](std::string const & key) override; - Json& operator[](int ind) override; + void Save(JsonWriter* writer) const override; Float const& GetNumber() && { return number_; } Float const& GetNumber() const & { return number_; } Float& GetNumber() & { return number_; } - bool operator==(Value const& rhs) const override; - Value& operator=(Value const& rhs) override; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kNumber; @@ -231,16 +275,12 @@ class JsonInteger : public Value { JsonInteger(JsonInteger &&that) noexcept : Value{ValueKind::kInteger}, integer_{that.integer_} {} - Json& operator[](std::string const & key) override; - Json& operator[](int ind) override; - bool operator==(Value const& rhs) const override; - Value& operator=(Value const& rhs) override; Int const& GetInteger() && { return integer_; } Int const& GetInteger() const & { return integer_; } Int& GetInteger() & { return integer_; } - void Save(JsonWriter* writer) override; + void Save(JsonWriter* writer) const override; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kInteger; @@ -253,13 +293,9 @@ class JsonNull : public Value { JsonNull(std::nullptr_t) : Value(ValueKind::kNull) {} // NOLINT JsonNull(JsonNull&&) noexcept : Value(ValueKind::kNull) {} - void Save(JsonWriter* writer) override; - - Json& operator[](std::string const & key) override; - Json& operator[](int ind) override; + void Save(JsonWriter* writer) const override; bool operator==(Value const& rhs) const override; - Value& operator=(Value const& rhs) override; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kNull; @@ -282,17 +318,13 @@ class JsonBoolean : public Value { JsonBoolean(JsonBoolean&& value) noexcept: // NOLINT Value(ValueKind::kBoolean), boolean_{value.boolean_} {} - void Save(JsonWriter* writer) override; - - Json& operator[](std::string const & key) override; - Json& operator[](int ind) override; + void Save(JsonWriter* writer) const override; bool const& GetBoolean() && { return boolean_; } bool const& GetBoolean() const & { return boolean_; } bool& GetBoolean() & { return boolean_; } bool operator==(Value const& rhs) const override; - Value& operator=(Value const& rhs) override; static bool IsClassOf(Value const* value) { return value->Type() == ValueKind::kBoolean; @@ -317,14 +349,22 @@ class JsonBoolean : public Value { * \endcode */ class Json { - friend JsonWriter; - public: - /*! \brief Load a Json object from string. */ - static Json Load(StringView str); + /** + * \brief Decode the JSON object. Optional parameter mode for choosing between text + * and binary (ubjson) input. + */ + static Json Load(StringView str, std::ios::openmode mode = std::ios::in); /*! \brief Pass your own JsonReader. */ static Json Load(JsonReader* reader); - static void Dump(Json json, std::string* out); + /** + * \brief Encode the JSON object. Optional parameter mode for choosing between text + * and binary (ubjson) output. + */ + static void Dump(Json json, std::string* out, std::ios::openmode mode = std::ios::out); + static void Dump(Json json, std::vector* out, std::ios::openmode mode = std::ios::out); + /*! \brief Use your own JsonWriter. */ + static void Dump(Json json, JsonWriter* writer); Json() : ptr_{new JsonNull} {} @@ -334,14 +374,12 @@ class Json { ptr_.reset(new JsonNumber(std::move(number))); return *this; } - // integer explicit Json(JsonInteger integer) : ptr_{new JsonInteger(std::move(integer))} {} Json& operator=(JsonInteger integer) { ptr_.reset(new JsonInteger(std::move(integer))); return *this; } - // array explicit Json(JsonArray list) : ptr_ {new JsonArray(std::move(list))} {} @@ -349,7 +387,15 @@ class Json { ptr_.reset(new JsonArray(std::move(array))); return *this; } - + // typed array + template + explicit Json(JsonTypedArray&& list) + : ptr_{new JsonTypedArray(std::forward>(list))} {} + template + Json& operator=(JsonTypedArray&& array) { + ptr_.reset(new JsonTypedArray(std::forward>(array))); + return *this; + } // object explicit Json(JsonObject object) : ptr_{new JsonObject(std::move(object))} {} @@ -381,7 +427,7 @@ class Json { // copy Json(Json const& other) = default; - Json& operator=(Json const& other); + Json& operator=(Json const& other) = default; // move Json(Json &&other) noexcept { std::swap(this->ptr_, other.ptr_); } Json &operator=(Json &&other) noexcept { @@ -410,10 +456,21 @@ class Json { return os; } + IntrusivePtr const& Ptr() const { return ptr_; } + private: IntrusivePtr ptr_; }; +/** + * \brief Check whether a Json object has specific type. + * + * \code + * Json json {Array{}}; + * bool is_array = IsA(json); + * CHECK(is_array); + * \endcode + */ template bool IsA(Json const& j) { auto const& v = j.GetValue(); @@ -421,7 +478,6 @@ bool IsA(Json const& j) { } namespace detail { - // Number template const& GetImpl(T& val) { // NOLINT return val.GetArray(); } +// Typed Array +template +std::vector& GetImpl(JsonTypedArray& val) { // NOLINT + return val.GetArray(); +} +template +std::vector const& GetImpl(JsonTypedArray const& val) { + return val.GetArray(); +} + // Object template const& GetImpl(T& val) { // NOLINT return val.GetObject(); } - } // namespace detail /*! diff --git a/include/xgboost/json_io.h b/include/xgboost/json_io.h index 2bde544ac..4827c2047 100644 --- a/include/xgboost/json_io.h +++ b/include/xgboost/json_io.h @@ -1,20 +1,20 @@ /*! - * Copyright (c) by Contributors 2019 + * Copyright (c) by Contributors 2019-2022 */ #ifndef XGBOOST_JSON_IO_H_ #define XGBOOST_JSON_IO_H_ -#include +#include #include +#include -#include +#include +#include +#include #include +#include #include #include -#include -#include -#include -#include -#include +#include namespace xgboost { /* @@ -47,7 +47,7 @@ class JsonReader { void SkipSpaces(); char GetNextChar() { - if (cursor_.Pos() == raw_str_.size()) { + if (XGBOOST_EXPECT((cursor_.Pos() == raw_str_.size()), false)) { return -1; } char ch = raw_str_[cursor_.Pos()]; @@ -109,12 +109,30 @@ class JsonReader { virtual ~JsonReader() = default; - Json Load(); + virtual Json Load(); }; class JsonWriter { - static constexpr size_t kIndentSize = 2; + template ::value>* = nullptr> + void Save(T const& v) { + this->Save(Json{v}); + } + template + void WriteArray(Array const* arr, Fn&& fn) { + stream_->emplace_back('['); + auto const& vec = arr->GetArray(); + size_t size = vec.size(); + for (size_t i = 0; i < size; ++i) { + auto const& value = vec[i]; + this->Save(fn(value)); + if (i != size - 1) { + stream_->emplace_back(','); + } + } + stream_->emplace_back(']'); + } + protected: std::vector* stream_; public: @@ -122,9 +140,13 @@ class JsonWriter { virtual ~JsonWriter() = default; - void Save(Json json); + virtual void Save(Json json); virtual void Visit(JsonArray const* arr); + virtual void Visit(F32Array const* arr); + virtual void Visit(U8Array const* arr); + virtual void Visit(I32Array const* arr); + virtual void Visit(I64Array const* arr); virtual void Visit(JsonObject const* obj); virtual void Visit(JsonNumber const* num); virtual void Visit(JsonInteger const* num); @@ -132,6 +154,113 @@ class JsonWriter { virtual void Visit(JsonString const* str); virtual void Visit(JsonBoolean const* boolean); }; + +#if defined(__GLIBC__) +template +T BuiltinBSwap(T v); + +template <> +inline uint16_t BuiltinBSwap(uint16_t v) { + return __builtin_bswap16(v); +} + +template <> +inline uint32_t BuiltinBSwap(uint32_t v) { + return __builtin_bswap32(v); +} + +template <> +inline uint64_t BuiltinBSwap(uint64_t v) { + return __builtin_bswap64(v); +} +#else +template +T BuiltinBSwap(T v) { + dmlc::ByteSwap(&v, sizeof(v), 1); + return v; +} +#endif // defined(__GLIBC__) + +template * = nullptr> +inline T ByteSwap(T v) { + return v; +} + +template * = nullptr> +inline T ByteSwap(T v) { + static_assert(std::is_pod::value, "Only pod is supported."); +#if DMLC_LITTLE_ENDIAN + auto constexpr kS = sizeof(T); + std::conditional_t> u; + std::memcpy(&u, &v, sizeof(u)); + u = BuiltinBSwap(u); + std::memcpy(&v, &u, sizeof(u)); +#endif // DMLC_LITTLE_ENDIAN + return v; +} + +/** + * \brief Reader for UBJSON https://ubjson.org/ + */ +class UBJReader : public JsonReader { + Json Parse(); + + template + T ReadStream() { + auto ptr = this->raw_str_.c_str() + cursor_.Pos(); + T v{0}; + std::memcpy(&v, ptr, sizeof(v)); + cursor_.Forward(sizeof(v)); + return v; + } + + template + T ReadPrimitive() { + auto v = ReadStream(); + v = ByteSwap(v); + return v; + } + + template + auto ParseTypedArray(int64_t n) { + TypedArray results{static_cast(n)}; + for (int64_t i = 0; i < n; ++i) { + auto v = this->ReadPrimitive(); + results.Set(i, v); + } + return Json{std::move(results)}; + } + + std::string DecodeStr(); + + Json ParseArray() override; + Json ParseObject() override; + + public: + using JsonReader::JsonReader; + Json Load() override; +}; + +/** + * \brief Writer for UBJSON https://ubjson.org/ + */ +class UBJWriter : public JsonWriter { + void Visit(JsonArray const* arr) override; + void Visit(F32Array const* arr) override; + void Visit(U8Array const* arr) override; + void Visit(I32Array const* arr) override; + void Visit(I64Array const* arr) override; + void Visit(JsonObject const* obj) override; + void Visit(JsonNumber const* num) override; + void Visit(JsonInteger const* num) override; + void Visit(JsonNull const* null) override; + void Visit(JsonString const* str) override; + void Visit(JsonBoolean const* boolean) override; + + public: + using JsonWriter::JsonWriter; + void Save(Json json) override; +}; } // namespace xgboost #endif // XGBOOST_JSON_IO_H_ diff --git a/src/common/json.cc b/src/common/json.cc index 924f1cecd..83ef27182 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -1,16 +1,18 @@ /*! - * Copyright (c) by Contributors 2019-2021 + * Copyright (c) by Contributors 2019-2022 */ #include "xgboost/json.h" +#include + #include #include #include #include #include -#include #include +#include "./math.h" #include "charconv.h" #include "xgboost/base.h" #include "xgboost/json_io.h" @@ -19,23 +21,20 @@ namespace xgboost { -void JsonWriter::Save(Json json) { - json.ptr_->Save(this); -} +void JsonWriter::Save(Json json) { json.Ptr()->Save(this); } void JsonWriter::Visit(JsonArray const* arr) { - stream_->emplace_back('['); - auto const& vec = arr->GetArray(); - size_t size = vec.size(); - for (size_t i = 0; i < size; ++i) { - auto const& value = vec[i]; - this->Save(value); - if (i != size - 1) { - stream_->emplace_back(','); - } - } - stream_->emplace_back(']'); + this->WriteArray(arr, [](auto const& v) { return v; }); } +void JsonWriter::Visit(F32Array const* arr) { + this->WriteArray(arr, [](float v) { return Json{v}; }); +} +namespace { +auto to_i64 = [](auto v) { return Json{static_cast(v)}; }; +} // anonymous namespace +void JsonWriter::Visit(U8Array const* arr) { this->WriteArray(arr, to_i64); } +void JsonWriter::Visit(I32Array const* arr) { this->WriteArray(arr, to_i64); } +void JsonWriter::Visit(I64Array const* arr) { this->WriteArray(arr, to_i64); } void JsonWriter::Visit(JsonObject const* obj) { stream_->emplace_back('{'); @@ -152,13 +151,28 @@ void JsonWriter::Visit(JsonBoolean const* boolean) { // Value std::string Value::TypeStr() const { switch (kind_) { - case ValueKind::kString: return "String"; break; - case ValueKind::kNumber: return "Number"; break; - case ValueKind::kObject: return "Object"; break; - case ValueKind::kArray: return "Array"; break; - case ValueKind::kBoolean: return "Boolean"; break; - case ValueKind::kNull: return "Null"; break; - case ValueKind::kInteger: return "Integer"; break; + case ValueKind::kString: + return "String"; + case ValueKind::kNumber: + return "Number"; + case ValueKind::kObject: + return "Object"; + case ValueKind::kArray: + return "Array"; + case ValueKind::kBoolean: + return "Boolean"; + case ValueKind::kNull: + return "Null"; + case ValueKind::kInteger: + return "Integer"; + case ValueKind::kNumberArray: + return "F32Array"; + case ValueKind::kU8Array: + return "U8Array"; + case ValueKind::kI32Array: + return "I32Array"; + case ValueKind::kI64Array: + return "I64Array"; } return ""; } @@ -170,6 +184,16 @@ Json& DummyJsonObject() { return obj; } +Json& Value::operator[](std::string const&) { + LOG(FATAL) << "Object of type " << TypeStr() << " can not be indexed by string."; + return DummyJsonObject(); +} + +Json& Value::operator[](int) { + LOG(FATAL) << "Object of type " << TypeStr() << " can not be indexed by Integer."; + return DummyJsonObject(); +} + // Json Object JsonObject::JsonObject(JsonObject && that) noexcept : Value(ValueKind::kObject), object_{std::move(that.object_)} {} @@ -177,16 +201,6 @@ JsonObject::JsonObject(JsonObject && that) noexcept : JsonObject::JsonObject(std::map &&object) noexcept : Value(ValueKind::kObject), object_{std::move(object)} {} -Json& JsonObject::operator[](std::string const & key) { - return object_[key]; -} - -Json& JsonObject::operator[](int ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by Integer."; - return DummyJsonObject(); -} - bool JsonObject::operator==(Value const& rhs) const { if (!IsA(&rhs)) { return false; @@ -194,89 +208,86 @@ bool JsonObject::operator==(Value const& rhs) const { return object_ == Cast(&rhs)->GetObject(); } -Value& JsonObject::operator=(Value const &rhs) { - JsonObject const* casted = Cast(&rhs); - object_ = casted->GetObject(); - return *this; -} - -void JsonObject::Save(JsonWriter* writer) { - writer->Visit(this); -} +void JsonObject::Save(JsonWriter* writer) const { writer->Visit(this); } // Json String -Json& JsonString::operator[](std::string const& ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by string."; - return DummyJsonObject(); -} - -Json& JsonString::operator[](int ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by Integer." - << " Please try obtaining std::string first."; - return DummyJsonObject(); -} - bool JsonString::operator==(Value const& rhs) const { if (!IsA(&rhs)) { return false; } return Cast(&rhs)->GetString() == str_; } -Value & JsonString::operator=(Value const &rhs) { - JsonString const* casted = Cast(&rhs); - str_ = casted->GetString(); - return *this; -} - // FIXME: UTF-8 parsing support. -void JsonString::Save(JsonWriter* writer) { - writer->Visit(this); -} +void JsonString::Save(JsonWriter* writer) const { writer->Visit(this); } // Json Array JsonArray::JsonArray(JsonArray && that) noexcept : Value(ValueKind::kArray), vec_{std::move(that.vec_)} {} -Json& JsonArray::operator[](std::string const& ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by string."; - return DummyJsonObject(); -} - -Json& JsonArray::operator[](int ind) { - return vec_.at(ind); -} - bool JsonArray::operator==(Value const& rhs) const { - if (!IsA(&rhs)) { return false; } + if (!IsA(&rhs)) { + return false; + } auto& arr = Cast(&rhs)->GetArray(); + if (vec_.size() != arr.size()) { + return false; + } return std::equal(arr.cbegin(), arr.cend(), vec_.cbegin()); } -Value & JsonArray::operator=(Value const &rhs) { - JsonArray const* casted = Cast(&rhs); - vec_ = casted->GetArray(); - return *this; -} +void JsonArray::Save(JsonWriter* writer) const { writer->Visit(this); } -void JsonArray::Save(JsonWriter* writer) { +// typed array +namespace { +// error C2668: 'fpclassify': ambiguous call to overloaded function +template +std::enable_if_t::value, bool> IsInfMSVCWar(T v) { + return std::isinf(v); +} +template +std::enable_if_t::value, bool> IsInfMSVCWar(T v) { + return false; +} +} // namespace + +template +void JsonTypedArray::Save(JsonWriter* writer) const { writer->Visit(this); } +template +bool JsonTypedArray::operator==(Value const& rhs) const { + if (!IsA>(&rhs)) { + return false; + } + auto& arr = Cast const>(&rhs)->GetArray(); + if (vec_.size() != arr.size()) { + return false; + } + if (std::is_same::value) { + for (size_t i = 0; i < vec_.size(); ++i) { + bool equal{false}; + if (common::CheckNAN(vec_[i])) { + equal = common::CheckNAN(arr[i]); + } else if (IsInfMSVCWar(vec_[i])) { + equal = IsInfMSVCWar(arr[i]); + } else { + equal = (arr[i] - vec_[i] == 0); + } + if (!equal) { + return false; + } + } + return true; + } + return std::equal(arr.cbegin(), arr.cend(), vec_.cbegin()); +} + +template class JsonTypedArray; +template class JsonTypedArray; +template class JsonTypedArray; +template class JsonTypedArray; + // Json Number -Json& JsonNumber::operator[](std::string const& ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by string."; - return DummyJsonObject(); -} - -Json& JsonNumber::operator[](int ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by Integer."; - return DummyJsonObject(); -} - bool JsonNumber::operator==(Value const& rhs) const { if (!IsA(&rhs)) { return false; } auto r_num = Cast(&rhs)->GetNumber(); @@ -289,98 +300,31 @@ bool JsonNumber::operator==(Value const& rhs) const { return number_ - r_num == 0; } -Value & JsonNumber::operator=(Value const &rhs) { - JsonNumber const* casted = Cast(&rhs); - number_ = casted->GetNumber(); - return *this; -} - -void JsonNumber::Save(JsonWriter* writer) { - writer->Visit(this); -} +void JsonNumber::Save(JsonWriter* writer) const { writer->Visit(this); } // Json Integer -Json& JsonInteger::operator[](std::string const& ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by string."; - return DummyJsonObject(); -} - -Json& JsonInteger::operator[](int ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by Integer."; - return DummyJsonObject(); -} - bool JsonInteger::operator==(Value const& rhs) const { if (!IsA(&rhs)) { return false; } return integer_ == Cast(&rhs)->GetInteger(); } -Value & JsonInteger::operator=(Value const &rhs) { - JsonInteger const* casted = Cast(&rhs); - integer_ = casted->GetInteger(); - return *this; -} - -void JsonInteger::Save(JsonWriter* writer) { - writer->Visit(this); -} +void JsonInteger::Save(JsonWriter* writer) const { writer->Visit(this); } // Json Null -Json& JsonNull::operator[](std::string const& ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by string."; - return DummyJsonObject(); -} - -Json& JsonNull::operator[](int ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by Integer."; - return DummyJsonObject(); -} - bool JsonNull::operator==(Value const& rhs) const { if (!IsA(&rhs)) { return false; } return true; } -Value & JsonNull::operator=(Value const &rhs) { - Cast(&rhs); // Checking only. - return *this; -} - -void JsonNull::Save(JsonWriter* writer) { - writer->Visit(this); -} +void JsonNull::Save(JsonWriter* writer) const { writer->Visit(this); } // Json Boolean -Json& JsonBoolean::operator[](std::string const& ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by string."; - return DummyJsonObject(); -} - -Json& JsonBoolean::operator[](int ) { - LOG(FATAL) << "Object of type " - << Value::TypeStr() << " can not be indexed by Integer."; - return DummyJsonObject(); -} - bool JsonBoolean::operator==(Value const& rhs) const { if (!IsA(&rhs)) { return false; } return boolean_ == Cast(&rhs)->GetBoolean(); } -Value & JsonBoolean::operator=(Value const &rhs) { - JsonBoolean const* casted = Cast(&rhs); - boolean_ = casted->GetBoolean(); - return *this; -} - -void JsonBoolean::Save(JsonWriter *writer) { - writer->Visit(this); -} +void JsonBoolean::Save(JsonWriter* writer) const { writer->Visit(this); } size_t constexpr JsonReader::kMaxNumLength; @@ -727,9 +671,15 @@ Json JsonReader::ParseBoolean() { return Json{JsonBoolean{result}}; } -Json Json::Load(StringView str) { - JsonReader reader(str); - Json json{reader.Load()}; +Json Json::Load(StringView str, std::ios::openmode mode) { + Json json; + if (mode & std::ios::binary) { + UBJReader reader{str}; + json = Json::Load(&reader); + } else { + JsonReader reader(str); + json = reader.Load(); + } return json; } @@ -738,18 +688,295 @@ Json Json::Load(JsonReader* reader) { return json; } -void Json::Dump(Json json, std::string* str) { +void Json::Dump(Json json, std::string* str, std::ios::openmode mode) { std::vector buffer; - JsonWriter writer(&buffer); - writer.Save(json); + Dump(json, &buffer, mode); str->resize(buffer.size()); std::copy(buffer.cbegin(), buffer.cend(), str->begin()); } -Json& Json::operator=(Json const &other) = default; +void Json::Dump(Json json, std::vector* str, std::ios::openmode mode) { + if (mode & std::ios::binary) { + UBJWriter writer{str}; + writer.Save(json); + } else { + JsonWriter writer(str); + writer.Save(json); + } +} + +void Json::Dump(Json json, JsonWriter* writer) { + writer->Save(json); +} static_assert(std::is_nothrow_move_constructible::value, ""); static_assert(std::is_nothrow_move_constructible::value, ""); static_assert(std::is_nothrow_move_constructible::value, ""); static_assert(std::is_nothrow_move_constructible::value, ""); + +Json UBJReader::ParseArray() { + auto marker = PeekNextChar(); + + if (marker == '$') { // typed array + GetNextChar(); // remove $ + marker = GetNextChar(); + auto type = marker; + GetConsecutiveChar('#'); + GetConsecutiveChar('L'); + auto n = this->ReadPrimitive(); + + marker = PeekNextChar(); + switch (type) { + case 'd': + return ParseTypedArray(n); + case 'U': + return ParseTypedArray(n); + case 'l': + return ParseTypedArray(n); + case 'L': + return ParseTypedArray(n); + default: + LOG(FATAL) << "`" + std::string{type} + "` is not supported for typed array."; // NOLINT + } + } + std::vector results; + if (marker == '#') { // array with length optimization + GetNextChar(); + GetConsecutiveChar('L'); + auto n = this->ReadPrimitive(); + results.resize(n); + for (int64_t i = 0; i < n; ++i) { + results[i] = Parse(); + } + } else { // normal array + while (marker != ']') { + results.emplace_back(Parse()); + marker = PeekNextChar(); + } + GetConsecutiveChar(']'); + } + + return Json{results}; +} + +std::string UBJReader::DecodeStr() { + // only L is supported right now. + GetConsecutiveChar('L'); + auto bsize = this->ReadPrimitive(); + + std::string str; + str.resize(bsize); + auto ptr = raw_str_.c_str() + cursor_.Pos(); + std::memcpy(&str[0], ptr, bsize); + for (int64_t i = 0; i < bsize; ++i) { + this->cursor_.Forward(); + } + return str; +} + +Json UBJReader::ParseObject() { + auto marker = PeekNextChar(); + std::map results; + + while (marker != '}') { + auto str = this->DecodeStr(); + results.emplace(str, this->Parse()); + marker = PeekNextChar(); + } + + GetConsecutiveChar('}'); + return Json{std::move(results)}; +} + +Json UBJReader::Load() { + Json result = Parse(); + return result; +} + +Json UBJReader::Parse() { + while (true) { + char c = PeekNextChar(); + if (c == -1) { + break; + } + + GetNextChar(); + switch (c) { + case '{': + return ParseObject(); + case '[': + return ParseArray(); + case 'Z': { + return Json{nullptr}; + } + case 'T': { + return Json{JsonBoolean{true}}; + } + case 'F': { + return Json{JsonBoolean{true}}; + } + case 'd': { + auto v = this->ReadPrimitive(); + return Json{v}; + } + case 'S': { + auto str = this->DecodeStr(); + return Json{str}; + } + case 'i': { + Integer::Int i = this->ReadPrimitive(); + return Json{i}; + } + case 'U': { + Integer::Int i = this->ReadPrimitive(); + return Json{i}; + } + case 'I': { + Integer::Int i = this->ReadPrimitive(); + return Json{i}; + } + case 'l': { + Integer::Int i = this->ReadPrimitive(); + return Json{i}; + } + case 'L': { + auto i = this->ReadPrimitive(); + return Json{i}; + } + case 'C': { + Integer::Int i = this->ReadPrimitive(); + return Json{i}; + } + case 'D': { + LOG(FATAL) << "f64 is not supported."; + } + case 'H': { + LOG(FATAL) << "High precision number is not supported."; + } + default: + Error("Unknown construct"); + } + } + return {}; +} + +namespace { +template +void WritePrimitive(T v, std::vector* stream) { + v = ByteSwap(v); + auto s = stream->size(); + stream->resize(s + sizeof(v)); + auto ptr = stream->data() + s; + std::memcpy(ptr, &v, sizeof(v)); +} + +void EncodeStr(std::vector* stream, std::string const& string) { + stream->push_back('L'); + + int64_t bsize = string.size(); + WritePrimitive(bsize, stream); + + auto s = stream->size(); + stream->resize(s + string.size()); + + auto ptr = stream->data() + s; + std::memcpy(ptr, string.data(), string.size()); +} +} // anonymous namespace + +void UBJWriter::Visit(JsonArray const* arr) { + stream_->emplace_back('['); + auto const& vec = arr->GetArray(); + int64_t n = vec.size(); + stream_->push_back('#'); + stream_->push_back('L'); + WritePrimitive(n, stream_); + for (auto const& v : vec) { + this->Save(v); + } +} + +template +void WriteTypedArray(JsonTypedArray const* arr, std::vector* stream) { + stream->emplace_back('['); + stream->push_back('$'); + if (std::is_same::value) { + stream->push_back('d'); + } else if (std::is_same::value) { + stream->push_back('i'); + } else if (std::is_same::value) { + stream->push_back('U'); + } else if (std::is_same::value) { + stream->push_back('l'); + } else if (std::is_same::value) { + stream->push_back('L'); + } else { + LOG(FATAL) << "Not implemented"; + } + + stream->push_back('#'); + stream->push_back('L'); + + auto n = arr->Size(); + WritePrimitive(n, stream); + auto s = stream->size(); + stream->resize(s + arr->Size() * sizeof(T)); + auto const& vec = arr->GetArray(); + for (size_t i = 0; i < n; ++i) { + auto v = ByteSwap(vec[i]); + std::memcpy(stream->data() + s, &v, sizeof(v)); + s += sizeof(v); + } +} + +void UBJWriter::Visit(F32Array const* arr) { WriteTypedArray(arr, stream_); } +void UBJWriter::Visit(U8Array const* arr) { WriteTypedArray(arr, stream_); } +void UBJWriter::Visit(I32Array const* arr) { WriteTypedArray(arr, stream_); } +void UBJWriter::Visit(I64Array const* arr) { WriteTypedArray(arr, stream_); } + +void UBJWriter::Visit(JsonObject const* obj) { + stream_->emplace_back('{'); + for (auto const& value : obj->GetObject()) { + auto const& key = value.first; + EncodeStr(stream_, key); + this->Save(value.second); + } + stream_->emplace_back('}'); +} + +void UBJWriter::Visit(JsonNumber const* num) { + stream_->push_back('d'); + auto val = num->GetNumber(); + WritePrimitive(val, stream_); +} + +void UBJWriter::Visit(JsonInteger const* num) { + auto i = num->GetInteger(); + if (i > std::numeric_limits::min() && i < std::numeric_limits::max()) { + stream_->push_back('i'); + WritePrimitive(static_cast(i), stream_); + } else if (i > std::numeric_limits::min() && i < std::numeric_limits::max()) { + stream_->push_back('I'); + WritePrimitive(static_cast(i), stream_); + } else if (i > std::numeric_limits::min() && i < std::numeric_limits::max()) { + stream_->push_back('l'); + WritePrimitive(static_cast(i), stream_); + } else { + stream_->push_back('L'); + WritePrimitive(i, stream_); + } +} + +void UBJWriter::Visit(JsonNull const* null) { stream_->push_back('Z'); } + +void UBJWriter::Visit(JsonString const* str) { + stream_->push_back('S'); + EncodeStr(stream_, str->GetString()); +} + +void UBJWriter::Visit(JsonBoolean const* boolean) { + stream_->push_back(boolean->GetBoolean() ? 'T' : 'F'); +} + +void UBJWriter::Save(Json json) { json.Ptr()->Save(this); } } // namespace xgboost diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index 363321709..1bbc49cf3 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -1,5 +1,5 @@ /*! - * Copyright (c) by Contributors 2019-2021 + * Copyright (c) by Contributors 2019-2022 */ #include #include @@ -255,6 +255,10 @@ TEST(Json, Null) { json = Json::Load({null_input.c_str(), null_input.size()}); ASSERT_TRUE(IsA(json["key"])); + + std::string dumped; + Json::Dump(json, &dumped, std::ios::binary); + ASSERT_TRUE(IsA(Json::Load(StringView{dumped}, std::ios::binary)["key"])); } TEST(Json, EmptyObject) { @@ -297,6 +301,10 @@ TEST(Json, Boolean) { Json j {Json::Load(StringView{str.c_str(), str.size()})}; ASSERT_EQ(get(j["left_child"]), true); ASSERT_EQ(get(j["right_child"]), false); + + std::string dumped; + Json::Dump(j, &dumped, std::ios::binary); + ASSERT_TRUE(get(Json::Load(StringView{dumped}, std::ios::binary)["left_child"])); } TEST(Json, Indexing) { @@ -532,7 +540,8 @@ TEST(Json, IntVSFloat) { } } -TEST(Json, RoundTrip) { +namespace { +void TestRroundTrip(std::ios::openmode mode) { uint32_t i = 0; SimpleLCG rng; SimpleRealUniformDistribution dist(1.0f, 4096.0f); @@ -541,10 +550,10 @@ TEST(Json, RoundTrip) { float f; std::memcpy(&f, &i, sizeof(f)); - Json jf { f }; + Json jf{f}; std::string str; - Json::Dump(jf, &str); - auto loaded = Json::Load({str.c_str(), str.size()}); + Json::Dump(jf, &str, mode); + auto loaded = Json::Load(StringView{str}, mode); if (XGBOOST_EXPECT(std::isnan(f), false)) { ASSERT_TRUE(std::isnan(get(loaded))); } else { @@ -558,6 +567,12 @@ TEST(Json, RoundTrip) { } } } +} // namespace + +TEST(Json, RoundTrip) { + TestRroundTrip(std::ios::out); + TestRroundTrip(std::ios::binary); +} TEST(Json, DISABLED_RoundTripExhaustive) { auto test = [](uint32_t i) { @@ -580,4 +595,87 @@ TEST(Json, DISABLED_RoundTripExhaustive) { test(static_cast(i)); } } + +TEST(Json, TypedArray) { + size_t n = 16; + F32Array f32{n}; + std::iota(f32.GetArray().begin(), f32.GetArray().end(), -8); + U8Array u8{n}; + std::iota(u8.GetArray().begin(), u8.GetArray().end(), 0); + I32Array i32{n}; + std::iota(i32.GetArray().begin(), i32.GetArray().end(), -8); + I64Array i64{n}; + std::iota(i64.GetArray().begin(), i64.GetArray().end(), -8); + + Json json{Object{}}; + json["u8"] = std::move(u8); + ASSERT_TRUE(IsA(json["u8"])); + json["f32"] = std::move(f32); + ASSERT_TRUE(IsA(json["f32"])); + json["i32"] = std::move(i32); + ASSERT_TRUE(IsA(json["i32"])); + json["i64"] = std::move(i64); + ASSERT_TRUE(IsA(json["i64"])); + + std::string str; + Json::Dump(json, &str); + { + auto loaded = Json::Load(StringView{str}); + // for text output there's no typed array. + ASSERT_TRUE(IsA(loaded["u8"])); + auto const& arr = loaded["f32"]; + for (int32_t i = -8; i < 8; ++i) { + ASSERT_EQ(get(arr[i + 8]), i); + } + } + + std::string binary; + Json::Dump(json, &binary, std::ios::binary); + { + auto loaded = Json::Load(StringView{binary}, std::ios::binary); + ASSERT_TRUE(IsA(loaded["u8"])); + auto const& arr = get(loaded["f32"]); + for (int32_t i = -8; i < 8; ++i) { + ASSERT_EQ(arr[i + 8], i); + } + } +} + +TEST(UBJson, Basic) { + auto run_test = [](StringView str) { + auto json = Json::Load(str); + std::vector stream; + UBJWriter writer{&stream}; + Json::Dump(json, &writer); + { + std::ofstream fout{"test.ubj", std::ios::binary | std::ios::out}; + fout.write(stream.data(), stream.size()); + } + + auto data = common::LoadSequentialFile("test.ubj"); + UBJReader reader{StringView{data}}; + json = reader.Load(); + return json; + }; + { + // empty + auto ret = run_test(R"({})"); + std::stringstream ss; + ss << ret; + ASSERT_EQ(ss.str(), "{}"); + } + { + auto ret = run_test(R"({"":[]})"); + std::stringstream ss; + ss << ret; + ASSERT_EQ(ss.str(), R"({"":[]})"); + } + { + // basic + auto ret = run_test(R"({"test": [2.71, 3.14, Infinity]})"); + ASSERT_TRUE(std::isinf(get(get(ret["test"])[2]))); + ASSERT_FLOAT_EQ(3.14, get(get(ret["test"])[1])); + ASSERT_FLOAT_EQ(2.71, get(get(ret["test"])[0])); + } +} } // namespace xgboost diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc index 38954f638..f971d70a1 100644 --- a/tests/cpp/test_serialization.cc +++ b/tests/cpp/test_serialization.cc @@ -1,15 +1,25 @@ -// Copyright (c) 2019-2020 by Contributors +// Copyright (c) 2019-2022 by Contributors #include #include #include #include #include #include +#include #include "helpers.h" #include "../../src/common/io.h" #include "../../src/common/random.h" namespace xgboost { +template +void CompareIntArray(Json l, Json r) { + auto const& l_arr = get(l); + auto const& r_arr = get(r); + ASSERT_EQ(l_arr.size(), r_arr.size()); + for (size_t i = 0; i < l_arr.size(); ++i) { + ASSERT_EQ(l_arr[i], r_arr[i]); + } +} void CompareJSON(Json l, Json r) { switch (l.GetValue().Type()) { @@ -45,6 +55,27 @@ void CompareJSON(Json l, Json r) { } break; } + case Value::ValueKind::kNumberArray: { + auto const& l_arr = get(l); + auto const& r_arr = get(r); + ASSERT_EQ(l_arr.size(), r_arr.size()); + for (size_t i = 0; i < l_arr.size(); ++i) { + ASSERT_NEAR(l_arr[i], r_arr[i], kRtEps); + } + break; + } + case Value::ValueKind::kU8Array: { + CompareIntArray(l, r); + break; + } + case Value::ValueKind::kI32Array: { + CompareIntArray(l, r); + break; + } + case Value::ValueKind::kI64Array: { + CompareIntArray(l, r); + break; + } case Value::ValueKind::kBoolean: { ASSERT_EQ(l, r); break;