Support f64 for ubjson. (#10055)
This commit is contained in:
parent
8ea705e4d5
commit
2e4ea5ecc0
@ -1,5 +1,5 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2019-2023 by XGBoost Contributors
|
* Copyright 2019-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_JSON_H_
|
#ifndef XGBOOST_JSON_H_
|
||||||
#define XGBOOST_JSON_H_
|
#define XGBOOST_JSON_H_
|
||||||
@ -42,7 +42,8 @@ class Value {
|
|||||||
kBoolean,
|
kBoolean,
|
||||||
kNull,
|
kNull,
|
||||||
// typed array for ubjson
|
// typed array for ubjson
|
||||||
kNumberArray,
|
kF32Array,
|
||||||
|
kF64Array,
|
||||||
kU8Array,
|
kU8Array,
|
||||||
kI32Array,
|
kI32Array,
|
||||||
kI64Array
|
kI64Array
|
||||||
@ -173,7 +174,11 @@ class JsonTypedArray : public Value {
|
|||||||
/**
|
/**
|
||||||
* @brief Typed UBJSON array for 32-bit floating point.
|
* @brief Typed UBJSON array for 32-bit floating point.
|
||||||
*/
|
*/
|
||||||
using F32Array = JsonTypedArray<float, Value::ValueKind::kNumberArray>;
|
using F32Array = JsonTypedArray<float, Value::ValueKind::kF32Array>;
|
||||||
|
/**
|
||||||
|
* @brief Typed UBJSON array for 64-bit floating point.
|
||||||
|
*/
|
||||||
|
using F64Array = JsonTypedArray<double, Value::ValueKind::kF64Array>;
|
||||||
/**
|
/**
|
||||||
* @brief Typed UBJSON array for uint8_t.
|
* @brief Typed UBJSON array for uint8_t.
|
||||||
*/
|
*/
|
||||||
@ -457,7 +462,7 @@ class Json {
|
|||||||
Json& operator[](int ind) const { return (*ptr_)[ind]; }
|
Json& operator[](int ind) const { return (*ptr_)[ind]; }
|
||||||
|
|
||||||
/*! \brief Return the reference to stored Json value. */
|
/*! \brief Return the reference to stored Json value. */
|
||||||
Value const& GetValue() const & { return *ptr_; }
|
[[nodiscard]] Value const& GetValue() const& { return *ptr_; }
|
||||||
Value const& GetValue() && { return *ptr_; }
|
Value const& GetValue() && { return *ptr_; }
|
||||||
Value& GetValue() & { return *ptr_; }
|
Value& GetValue() & { return *ptr_; }
|
||||||
|
|
||||||
@ -472,7 +477,7 @@ class Json {
|
|||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
IntrusivePtr<Value> const& Ptr() const { return ptr_; }
|
[[nodiscard]] IntrusivePtr<Value> const& Ptr() const { return ptr_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
IntrusivePtr<Value> ptr_{new JsonNull};
|
IntrusivePtr<Value> ptr_{new JsonNull};
|
||||||
|
|||||||
@ -142,6 +142,7 @@ class JsonWriter {
|
|||||||
|
|
||||||
virtual void Visit(JsonArray const* arr);
|
virtual void Visit(JsonArray const* arr);
|
||||||
virtual void Visit(F32Array const* arr);
|
virtual void Visit(F32Array const* arr);
|
||||||
|
virtual void Visit(F64Array const*) { LOG(FATAL) << "Only UBJSON format can handle f64 array."; }
|
||||||
virtual void Visit(U8Array const* arr);
|
virtual void Visit(U8Array const* arr);
|
||||||
virtual void Visit(I32Array const* arr);
|
virtual void Visit(I32Array const* arr);
|
||||||
virtual void Visit(I64Array const* arr);
|
virtual void Visit(I64Array const* arr);
|
||||||
@ -245,6 +246,7 @@ class UBJReader : public JsonReader {
|
|||||||
class UBJWriter : public JsonWriter {
|
class UBJWriter : public JsonWriter {
|
||||||
void Visit(JsonArray const* arr) override;
|
void Visit(JsonArray const* arr) override;
|
||||||
void Visit(F32Array const* arr) override;
|
void Visit(F32Array const* arr) override;
|
||||||
|
void Visit(F64Array const* arr) override;
|
||||||
void Visit(U8Array const* arr) override;
|
void Visit(U8Array const* arr) override;
|
||||||
void Visit(I32Array const* arr) override;
|
void Visit(I32Array const* arr) override;
|
||||||
void Visit(I64Array const* arr) override;
|
void Visit(I64Array const* arr) override;
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2019-2023, XGBoost Contributors
|
* Copyright 2019-2024, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
|
|
||||||
#include <array> // for array
|
#include <array> // for array
|
||||||
#include <cctype> // for isdigit
|
#include <cctype> // for isdigit
|
||||||
#include <cmath> // for isinf, isnan
|
#include <cmath> // for isinf, isnan
|
||||||
|
#include <cstdint> // for uint8_t, uint16_t, uint32_t
|
||||||
#include <cstdio> // for EOF
|
#include <cstdio> // for EOF
|
||||||
#include <cstdlib> // for size_t, strtof
|
#include <cstdlib> // for size_t, strtof
|
||||||
#include <cstring> // for memcpy
|
#include <cstring> // for memcpy
|
||||||
@ -72,15 +73,16 @@ void JsonWriter::Visit(JsonNumber const* num) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void JsonWriter::Visit(JsonInteger const* num) {
|
void JsonWriter::Visit(JsonInteger const* num) {
|
||||||
char i2s_buffer_[NumericLimits<int64_t>::kToCharsSize];
|
std::array<char, NumericLimits<int64_t>::kToCharsSize> i2s_buffer_;
|
||||||
auto i = num->GetInteger();
|
auto i = num->GetInteger();
|
||||||
auto ret = to_chars(i2s_buffer_, i2s_buffer_ + NumericLimits<int64_t>::kToCharsSize, i);
|
auto ret =
|
||||||
|
to_chars(i2s_buffer_.data(), i2s_buffer_.data() + NumericLimits<int64_t>::kToCharsSize, i);
|
||||||
auto end = ret.ptr;
|
auto end = ret.ptr;
|
||||||
CHECK(ret.ec == std::errc());
|
CHECK(ret.ec == std::errc());
|
||||||
auto digits = std::distance(i2s_buffer_, end);
|
auto digits = std::distance(i2s_buffer_.data(), end);
|
||||||
auto ori_size = stream_->size();
|
auto ori_size = stream_->size();
|
||||||
stream_->resize(ori_size + digits);
|
stream_->resize(ori_size + digits);
|
||||||
std::memcpy(stream_->data() + ori_size, i2s_buffer_, digits);
|
std::memcpy(stream_->data() + ori_size, i2s_buffer_.data(), digits);
|
||||||
}
|
}
|
||||||
|
|
||||||
void JsonWriter::Visit(JsonNull const* ) {
|
void JsonWriter::Visit(JsonNull const* ) {
|
||||||
@ -143,8 +145,10 @@ std::string Value::TypeStr() const {
|
|||||||
return "Null";
|
return "Null";
|
||||||
case ValueKind::kInteger:
|
case ValueKind::kInteger:
|
||||||
return "Integer";
|
return "Integer";
|
||||||
case ValueKind::kNumberArray:
|
case ValueKind::kF32Array:
|
||||||
return "F32Array";
|
return "F32Array";
|
||||||
|
case ValueKind::kF64Array:
|
||||||
|
return "F64Array";
|
||||||
case ValueKind::kU8Array:
|
case ValueKind::kU8Array:
|
||||||
return "U8Array";
|
return "U8Array";
|
||||||
case ValueKind::kI32Array:
|
case ValueKind::kI32Array:
|
||||||
@ -262,10 +266,11 @@ bool JsonTypedArray<T, kind>::operator==(Value const& rhs) const {
|
|||||||
return std::equal(arr.cbegin(), arr.cend(), vec_.cbegin());
|
return std::equal(arr.cbegin(), arr.cend(), vec_.cbegin());
|
||||||
}
|
}
|
||||||
|
|
||||||
template class JsonTypedArray<float, Value::ValueKind::kNumberArray>;
|
template class JsonTypedArray<float, Value::ValueKind::kF32Array>;
|
||||||
template class JsonTypedArray<uint8_t, Value::ValueKind::kU8Array>;
|
template class JsonTypedArray<double, Value::ValueKind::kF64Array>;
|
||||||
template class JsonTypedArray<int32_t, Value::ValueKind::kI32Array>;
|
template class JsonTypedArray<std::uint8_t, Value::ValueKind::kU8Array>;
|
||||||
template class JsonTypedArray<int64_t, Value::ValueKind::kI64Array>;
|
template class JsonTypedArray<std::int32_t, Value::ValueKind::kI32Array>;
|
||||||
|
template class JsonTypedArray<std::int64_t, Value::ValueKind::kI64Array>;
|
||||||
|
|
||||||
// Json Number
|
// Json Number
|
||||||
bool JsonNumber::operator==(Value const& rhs) const {
|
bool JsonNumber::operator==(Value const& rhs) const {
|
||||||
@ -708,6 +713,8 @@ Json UBJReader::ParseArray() {
|
|||||||
switch (type) {
|
switch (type) {
|
||||||
case 'd':
|
case 'd':
|
||||||
return ParseTypedArray<F32Array>(n);
|
return ParseTypedArray<F32Array>(n);
|
||||||
|
case 'D':
|
||||||
|
return ParseTypedArray<F64Array>(n);
|
||||||
case 'U':
|
case 'U':
|
||||||
return ParseTypedArray<U8Array>(n);
|
return ParseTypedArray<U8Array>(n);
|
||||||
case 'l':
|
case 'l':
|
||||||
@ -797,6 +804,10 @@ Json UBJReader::Parse() {
|
|||||||
auto v = this->ReadPrimitive<float>();
|
auto v = this->ReadPrimitive<float>();
|
||||||
return Json{v};
|
return Json{v};
|
||||||
}
|
}
|
||||||
|
case 'D': {
|
||||||
|
auto v = this->ReadPrimitive<double>();
|
||||||
|
return Json{v};
|
||||||
|
}
|
||||||
case 'S': {
|
case 'S': {
|
||||||
auto str = this->DecodeStr();
|
auto str = this->DecodeStr();
|
||||||
return Json{str};
|
return Json{str};
|
||||||
@ -825,10 +836,6 @@ Json UBJReader::Parse() {
|
|||||||
Integer::Int i = this->ReadPrimitive<char>();
|
Integer::Int i = this->ReadPrimitive<char>();
|
||||||
return Json{i};
|
return Json{i};
|
||||||
}
|
}
|
||||||
case 'D': {
|
|
||||||
LOG(FATAL) << "f64 is not supported.";
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case 'H': {
|
case 'H': {
|
||||||
LOG(FATAL) << "High precision number is not supported.";
|
LOG(FATAL) << "High precision number is not supported.";
|
||||||
break;
|
break;
|
||||||
@ -882,6 +889,8 @@ void WriteTypedArray(JsonTypedArray<T, kind> const* arr, std::vector<char>* stre
|
|||||||
stream->push_back('$');
|
stream->push_back('$');
|
||||||
if (std::is_same<T, float>::value) {
|
if (std::is_same<T, float>::value) {
|
||||||
stream->push_back('d');
|
stream->push_back('d');
|
||||||
|
} else if (std::is_same_v<T, double>) {
|
||||||
|
stream->push_back('D');
|
||||||
} else if (std::is_same<T, int8_t>::value) {
|
} else if (std::is_same<T, int8_t>::value) {
|
||||||
stream->push_back('i');
|
stream->push_back('i');
|
||||||
} else if (std::is_same<T, uint8_t>::value) {
|
} else if (std::is_same<T, uint8_t>::value) {
|
||||||
@ -910,6 +919,7 @@ void WriteTypedArray(JsonTypedArray<T, kind> const* arr, std::vector<char>* stre
|
|||||||
}
|
}
|
||||||
|
|
||||||
void UBJWriter::Visit(F32Array const* arr) { WriteTypedArray(arr, stream_); }
|
void UBJWriter::Visit(F32Array const* arr) { WriteTypedArray(arr, stream_); }
|
||||||
|
void UBJWriter::Visit(F64Array const* arr) { WriteTypedArray(arr, stream_); }
|
||||||
void UBJWriter::Visit(U8Array const* arr) { WriteTypedArray(arr, stream_); }
|
void UBJWriter::Visit(U8Array const* arr) { WriteTypedArray(arr, stream_); }
|
||||||
void UBJWriter::Visit(I32Array const* arr) { WriteTypedArray(arr, stream_); }
|
void UBJWriter::Visit(I32Array const* arr) { WriteTypedArray(arr, stream_); }
|
||||||
void UBJWriter::Visit(I64Array const* arr) { WriteTypedArray(arr, stream_); }
|
void UBJWriter::Visit(I64Array const* arr) { WriteTypedArray(arr, stream_); }
|
||||||
|
|||||||
@ -639,6 +639,40 @@ TEST(Json, TypedArray) {
|
|||||||
ASSERT_EQ(arr[i + 8], i);
|
ASSERT_EQ(arr[i + 8], i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
Json f64{Object{}};
|
||||||
|
auto array = F64Array();
|
||||||
|
auto& vec = array.GetArray();
|
||||||
|
// Construct test data
|
||||||
|
vec.resize(18);
|
||||||
|
std::iota(vec.begin(), vec.end(), 0.0);
|
||||||
|
// special values
|
||||||
|
vec.push_back(std::numeric_limits<double>::epsilon());
|
||||||
|
vec.push_back(std::numeric_limits<double>::max());
|
||||||
|
vec.push_back(std::numeric_limits<double>::min());
|
||||||
|
vec.push_back(std::numeric_limits<double>::denorm_min());
|
||||||
|
vec.push_back(std::numeric_limits<double>::quiet_NaN());
|
||||||
|
|
||||||
|
static_assert(
|
||||||
|
std::is_same_v<double, typename std::remove_reference_t<decltype(vec)>::value_type>);
|
||||||
|
|
||||||
|
f64["f64"] = std::move(array);
|
||||||
|
ASSERT_TRUE(IsA<F64Array>(f64["f64"]));
|
||||||
|
std::vector<char> out;
|
||||||
|
Json::Dump(f64, &out, std::ios::binary);
|
||||||
|
|
||||||
|
auto loaded = Json::Load(StringView{out.data(), out.size()}, std::ios::binary);
|
||||||
|
ASSERT_TRUE(IsA<F64Array>(loaded["f64"]));
|
||||||
|
auto const& result = get<F64Array const>(loaded["f64"]);
|
||||||
|
|
||||||
|
auto& vec1 = get<F64Array const>(f64["f64"]);
|
||||||
|
ASSERT_EQ(result.size(), vec1.size());
|
||||||
|
for (std::size_t i = 0; i < vec1.size() - 1; ++i) {
|
||||||
|
ASSERT_EQ(result[i], vec1[i]);
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(std::isnan(result.back()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UBJson, Basic) {
|
TEST(UBJson, Basic) {
|
||||||
@ -694,6 +728,7 @@ TEST(UBJson, Basic) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST(Json, TypeCheck) {
|
TEST(Json, TypeCheck) {
|
||||||
Json config{Object{}};
|
Json config{Object{}};
|
||||||
config["foo"] = String{"bar"};
|
config["foo"] = String{"bar"};
|
||||||
|
|||||||
@ -60,7 +60,7 @@ void CompareJSON(Json l, Json r) {
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Value::ValueKind::kNumberArray: {
|
case Value::ValueKind::kF32Array: {
|
||||||
auto const& l_arr = get<F32Array const>(l);
|
auto const& l_arr = get<F32Array const>(l);
|
||||||
auto const& r_arr = get<F32Array const>(r);
|
auto const& r_arr = get<F32Array const>(r);
|
||||||
ASSERT_EQ(l_arr.size(), r_arr.size());
|
ASSERT_EQ(l_arr.size(), r_arr.size());
|
||||||
@ -69,6 +69,15 @@ void CompareJSON(Json l, Json r) {
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case Value::ValueKind::kF64Array: {
|
||||||
|
auto const& l_arr = get<F64Array const>(l);
|
||||||
|
auto const& r_arr = get<F64Array const>(r);
|
||||||
|
ASSERT_EQ(l_arr.size(), r_arr.size());
|
||||||
|
for (size_t i = 0; i < l_arr.size(); ++i) {
|
||||||
|
ASSERT_NEAR(l_arr[i], r_arr[i], kRtEps);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
case Value::ValueKind::kU8Array: {
|
case Value::ValueKind::kU8Array: {
|
||||||
CompareIntArray<U8Array>(l, r);
|
CompareIntArray<U8Array>(l, r);
|
||||||
break;
|
break;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user