diff --git a/include/xgboost/json.h b/include/xgboost/json.h index a5872ec3a..77ca6a510 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023 by XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #ifndef XGBOOST_JSON_H_ #define XGBOOST_JSON_H_ @@ -42,7 +42,8 @@ class Value { kBoolean, kNull, // typed array for ubjson - kNumberArray, + kF32Array, + kF64Array, kU8Array, kI32Array, kI64Array @@ -173,7 +174,11 @@ class JsonTypedArray : public Value { /** * @brief Typed UBJSON array for 32-bit floating point. */ -using F32Array = JsonTypedArray; +using F32Array = JsonTypedArray; +/** + * @brief Typed UBJSON array for 64-bit floating point. + */ +using F64Array = JsonTypedArray; /** * @brief Typed UBJSON array for uint8_t. */ @@ -457,9 +462,9 @@ class Json { Json& operator[](int ind) const { return (*ptr_)[ind]; } /*! \brief Return the reference to stored Json value. */ - Value const& GetValue() const & { return *ptr_; } - Value const& GetValue() && { return *ptr_; } - Value& GetValue() & { return *ptr_; } + [[nodiscard]] Value const& GetValue() const& { return *ptr_; } + Value const& GetValue() && { return *ptr_; } + Value& GetValue() & { return *ptr_; } bool operator==(Json const& rhs) const { return *ptr_ == *(rhs.ptr_); @@ -472,7 +477,7 @@ class Json { return os; } - IntrusivePtr const& Ptr() const { return ptr_; } + [[nodiscard]] IntrusivePtr const& Ptr() const { return ptr_; } private: IntrusivePtr ptr_{new JsonNull}; diff --git a/include/xgboost/json_io.h b/include/xgboost/json_io.h index 3a73d170a..ce3d25c37 100644 --- a/include/xgboost/json_io.h +++ b/include/xgboost/json_io.h @@ -142,6 +142,7 @@ class JsonWriter { virtual void Visit(JsonArray const* arr); virtual void Visit(F32Array const* arr); + virtual void Visit(F64Array const*) { LOG(FATAL) << "Only UBJSON format can handle f64 array."; } virtual void Visit(U8Array const* arr); virtual void Visit(I32Array const* arr); virtual void Visit(I64Array const* arr); @@ -244,7 +245,8 @@ class UBJReader : public JsonReader { */ class UBJWriter : public JsonWriter { void Visit(JsonArray const* arr) override; - void Visit(F32Array const* arr) override; + void Visit(F32Array const* arr) override; + void Visit(F64Array const* arr) override; void Visit(U8Array const* arr) override; void Visit(I32Array const* arr) override; void Visit(I64Array const* arr) override; diff --git a/src/common/json.cc b/src/common/json.cc index 21be2a5bc..2887eeccf 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -1,11 +1,12 @@ /** - * Copyright 2019-2023, XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include "xgboost/json.h" #include // for array #include // for isdigit #include // for isinf, isnan +#include // for uint8_t, uint16_t, uint32_t #include // for EOF #include // for size_t, strtof #include // for memcpy @@ -72,15 +73,16 @@ void JsonWriter::Visit(JsonNumber const* num) { } void JsonWriter::Visit(JsonInteger const* num) { - char i2s_buffer_[NumericLimits::kToCharsSize]; + std::array::kToCharsSize> i2s_buffer_; auto i = num->GetInteger(); - auto ret = to_chars(i2s_buffer_, i2s_buffer_ + NumericLimits::kToCharsSize, i); + auto ret = + to_chars(i2s_buffer_.data(), i2s_buffer_.data() + NumericLimits::kToCharsSize, i); auto end = ret.ptr; CHECK(ret.ec == std::errc()); - auto digits = std::distance(i2s_buffer_, end); + auto digits = std::distance(i2s_buffer_.data(), end); auto ori_size = stream_->size(); stream_->resize(ori_size + digits); - std::memcpy(stream_->data() + ori_size, i2s_buffer_, digits); + std::memcpy(stream_->data() + ori_size, i2s_buffer_.data(), digits); } void JsonWriter::Visit(JsonNull const* ) { @@ -143,8 +145,10 @@ std::string Value::TypeStr() const { return "Null"; case ValueKind::kInteger: return "Integer"; - case ValueKind::kNumberArray: + case ValueKind::kF32Array: return "F32Array"; + case ValueKind::kF64Array: + return "F64Array"; case ValueKind::kU8Array: return "U8Array"; case ValueKind::kI32Array: @@ -262,10 +266,11 @@ bool JsonTypedArray::operator==(Value const& rhs) const { return std::equal(arr.cbegin(), arr.cend(), vec_.cbegin()); } -template class JsonTypedArray; -template class JsonTypedArray; -template class JsonTypedArray; -template class JsonTypedArray; +template class JsonTypedArray; +template class JsonTypedArray; +template class JsonTypedArray; +template class JsonTypedArray; +template class JsonTypedArray; // Json Number bool JsonNumber::operator==(Value const& rhs) const { @@ -708,6 +713,8 @@ Json UBJReader::ParseArray() { switch (type) { case 'd': return ParseTypedArray(n); + case 'D': + return ParseTypedArray(n); case 'U': return ParseTypedArray(n); case 'l': @@ -797,6 +804,10 @@ Json UBJReader::Parse() { auto v = this->ReadPrimitive(); return Json{v}; } + case 'D': { + auto v = this->ReadPrimitive(); + return Json{v}; + } case 'S': { auto str = this->DecodeStr(); return Json{str}; @@ -825,10 +836,6 @@ Json UBJReader::Parse() { Integer::Int i = this->ReadPrimitive(); return Json{i}; } - case 'D': { - LOG(FATAL) << "f64 is not supported."; - break; - } case 'H': { LOG(FATAL) << "High precision number is not supported."; break; @@ -882,6 +889,8 @@ void WriteTypedArray(JsonTypedArray const* arr, std::vector* stre stream->push_back('$'); if (std::is_same::value) { stream->push_back('d'); + } else if (std::is_same_v) { + stream->push_back('D'); } else if (std::is_same::value) { stream->push_back('i'); } else if (std::is_same::value) { @@ -910,6 +919,7 @@ void WriteTypedArray(JsonTypedArray const* arr, std::vector* stre } void UBJWriter::Visit(F32Array const* arr) { WriteTypedArray(arr, stream_); } +void UBJWriter::Visit(F64Array const* arr) { WriteTypedArray(arr, stream_); } void UBJWriter::Visit(U8Array const* arr) { WriteTypedArray(arr, stream_); } void UBJWriter::Visit(I32Array const* arr) { WriteTypedArray(arr, stream_); } void UBJWriter::Visit(I64Array const* arr) { WriteTypedArray(arr, stream_); } diff --git a/tests/cpp/common/test_json.cc b/tests/cpp/common/test_json.cc index 155cf04ad..72163efd7 100644 --- a/tests/cpp/common/test_json.cc +++ b/tests/cpp/common/test_json.cc @@ -639,6 +639,40 @@ TEST(Json, TypedArray) { ASSERT_EQ(arr[i + 8], i); } } + + { + Json f64{Object{}}; + auto array = F64Array(); + auto& vec = array.GetArray(); + // Construct test data + vec.resize(18); + std::iota(vec.begin(), vec.end(), 0.0); + // special values + vec.push_back(std::numeric_limits::epsilon()); + vec.push_back(std::numeric_limits::max()); + vec.push_back(std::numeric_limits::min()); + vec.push_back(std::numeric_limits::denorm_min()); + vec.push_back(std::numeric_limits::quiet_NaN()); + + static_assert( + std::is_same_v::value_type>); + + f64["f64"] = std::move(array); + ASSERT_TRUE(IsA(f64["f64"])); + std::vector out; + Json::Dump(f64, &out, std::ios::binary); + + auto loaded = Json::Load(StringView{out.data(), out.size()}, std::ios::binary); + ASSERT_TRUE(IsA(loaded["f64"])); + auto const& result = get(loaded["f64"]); + + auto& vec1 = get(f64["f64"]); + ASSERT_EQ(result.size(), vec1.size()); + for (std::size_t i = 0; i < vec1.size() - 1; ++i) { + ASSERT_EQ(result[i], vec1[i]); + } + ASSERT_TRUE(std::isnan(result.back())); + } } TEST(UBJson, Basic) { @@ -694,6 +728,7 @@ TEST(UBJson, Basic) { } } + TEST(Json, TypeCheck) { Json config{Object{}}; config["foo"] = String{"bar"}; diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc index 0b65220ab..283a56fc5 100644 --- a/tests/cpp/test_serialization.cc +++ b/tests/cpp/test_serialization.cc @@ -60,7 +60,7 @@ void CompareJSON(Json l, Json r) { } break; } - case Value::ValueKind::kNumberArray: { + case Value::ValueKind::kF32Array: { auto const& l_arr = get(l); auto const& r_arr = get(r); ASSERT_EQ(l_arr.size(), r_arr.size()); @@ -69,6 +69,15 @@ void CompareJSON(Json l, Json r) { } break; } + case Value::ValueKind::kF64Array: { + auto const& l_arr = get(l); + auto const& r_arr = get(r); + ASSERT_EQ(l_arr.size(), r_arr.size()); + for (size_t i = 0; i < l_arr.size(); ++i) { + ASSERT_NEAR(l_arr[i], r_arr[i], kRtEps); + } + break; + } case Value::ValueKind::kU8Array: { CompareIntArray(l, r); break;