diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 5f7e75fd9..6e8c09b7d 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -69,6 +69,7 @@ #include "../src/learner.cc" #include "../src/logging.cc" #include "../src/common/common.cc" +#include "../src/common/charconv.cc" #include "../src/common/timer.cc" #include "../src/common/host_device_vector.cc" #include "../src/common/hist_util.cc" diff --git a/include/xgboost/json.h b/include/xgboost/json.h index c190505bf..5048bb7ec 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -1,18 +1,18 @@ /*! - * Copyright (c) by Contributors 2019 + * Copyright (c) by XGBoost Contributors 2019-2020 */ #ifndef XGBOOST_JSON_H_ #define XGBOOST_JSON_H_ #include #include -#include #include #include #include #include #include +#include namespace xgboost { @@ -331,13 +331,7 @@ class Json { static Json Load(StringView str); /*! \brief Pass your own JsonReader. */ static Json Load(JsonReader* reader); - /*! \brief Dump json into stream. 
*/ - static void Dump(Json json, std::ostream* stream, - bool pretty = ConsoleLogger::ShouldLog( - ConsoleLogger::LogVerbosity::kDebug)); - static void Dump(Json json, std::string* out, - bool pretty = ConsoleLogger::ShouldLog( - ConsoleLogger::LogVerbosity::kDebug)); + static void Dump(Json json, std::string* out); Json() : ptr_{new JsonNull} {} diff --git a/include/xgboost/json_io.h b/include/xgboost/json_io.h index a58afe172..67a829ee7 100644 --- a/include/xgboost/json_io.h +++ b/include/xgboost/json_io.h @@ -4,7 +4,9 @@ #ifndef XGBOOST_JSON_IO_H_ #define XGBOOST_JSON_IO_H_ #include +#include +#include #include #include #include @@ -15,20 +17,6 @@ #include namespace xgboost { - -template -class FixedPrecisionStreamContainer : public std::basic_stringstream< - char, std::char_traits, Allocator> { - public: - FixedPrecisionStreamContainer() { - this->precision(std::numeric_limits::max_digits10); - this->imbue(std::locale("C")); - this->setf(std::ios::scientific); - } -}; - -using FixedPrecisionStream = FixedPrecisionStreamContainer>; - /* * \brief A json reader, currently error checking and utf-8 is not fully supported. */ @@ -45,13 +33,11 @@ class JsonReader { SourceLocation() = default; size_t Pos() const { return pos_; } - SourceLocation& Forward() { + void Forward() { pos_++; - return *this; } - SourceLocation& Forward(uint32_t n) { + void Forward(uint32_t n) { pos_ += n; - return *this; } } cursor_; @@ -77,14 +63,17 @@ class JsonReader { return ch; } + /* \brief Skip spaces and consume next character. */ char GetNextNonSpaceChar() { SkipSpaces(); return GetNextChar(); } - - char GetChar(char c) { - char result = GetNextNonSpaceChar(); - if (result != c) { Expect(c, result); } + /* \brief Consume next character without first skipping empty space, throw when the next + * character is not the expected one. 
+ */ + char GetConsecutiveChar(char expected_char) { + char result = GetNextChar(); + if (XGBOOST_EXPECT(result != expected_char, false)) { Expect(expected_char, result); } return result; } @@ -95,7 +84,11 @@ class JsonReader { std::string msg = "Expecting: \""; msg += c; msg += "\", got: \""; - msg += std::string {got} + " \""; + if (got == -1) { + msg += "EOF\""; + } else { + msg += std::to_string(got) + " \""; + } Error(msg); } @@ -119,38 +112,16 @@ class JsonReader { class JsonWriter { static constexpr size_t kIndentSize = 2; - FixedPrecisionStream convertor_; size_t n_spaces_; - std::ostream* stream_; - bool pretty_; + std::vector* stream_; public: - JsonWriter(std::ostream* stream, bool pretty) : - n_spaces_{0}, stream_{stream}, pretty_{pretty} {} + explicit JsonWriter(std::vector* stream) : + n_spaces_{0}, stream_{stream} {} virtual ~JsonWriter() = default; - void NewLine() { - if (pretty_) { - *stream_ << u8"\n" << std::string(n_spaces_, ' '); - } - } - - void BeginIndent() { - n_spaces_ += kIndentSize; - } - void EndIndent() { - n_spaces_ -= kIndentSize; - } - - void Write(std::string str) { - *stream_ << str; - } - void Write(StringView str) { - stream_->write(str.c_str(), str.size()); - } - void Save(Json json); virtual void Visit(JsonArray const* arr); diff --git a/src/common/charconv.cc b/src/common/charconv.cc new file mode 100644 index 000000000..259419e8a --- /dev/null +++ b/src/common/charconv.cc @@ -0,0 +1,942 @@ +/*! + * Copyright 2020 by XGBoost Contributors + * + * \brief An implementation of the Ryu algorithm: + * + * https://dl.acm.org/citation.cfm?id=3192369 + * + * The code is adapted from the original (half) C implementation: + * https://github.com/ulfjack/ryu.git with some more comments and tidying. License is + * attached below. + * + * Copyright 2018 Ulf Adams + * + * The contents of this file may be used under the terms of the Apache License, + * Version 2.0. 
+ * + * (See accompanying file LICENSE-Apache or copy at + * http://www.apache.org/licenses/LICENSE-2.0) + * + * Alternatively, the contents of this file may be used under the terms of + * the Boost Software License, Version 1.0. + * (See accompanying file LICENSE-Boost or copy at + * https://www.boost.org/LICENSE_1_0.txt) + * + * Unless required by applicable law or agreed to in writing, this software + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. + */ +#include +#include +#include +#include +#include + +#include "xgboost/logging.h" +#include "charconv.h" + +#if defined(_MSC_VER) +#include +namespace { +inline int32_t __builtin_clzll(uint64_t x) { + return static_cast(__lzcnt64(x)); +} +} // anonymous namespace +#endif + +/* + * We did some cleanup from the original implementation instead of doing line to line + * port. + * + * The basic concept of floating rounding is, for a floating point number, we need to + * convert base2 to base10. During which we need to implement correct rounding. Hence on + * base2 we have: + * + * {low, value, high} + * + * 3 values, representing round down, no rounding, and round up. In the original + * implementation and paper, variables representing these 3 values are typically postfixed + * with m, r, p like {vr, vm, vp}. Here we name them more verbosely. 
+ */ + +namespace xgboost { +namespace detail { +static constexpr char kItoaLut[200] = { + '0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0', + '7', '0', '8', '0', '9', '1', '0', '1', '1', '1', '2', '1', '3', '1', '4', + '1', '5', '1', '6', '1', '7', '1', '8', '1', '9', '2', '0', '2', '1', '2', + '2', '2', '3', '2', '4', '2', '5', '2', '6', '2', '7', '2', '8', '2', '9', + '3', '0', '3', '1', '3', '2', '3', '3', '3', '4', '3', '5', '3', '6', '3', + '7', '3', '8', '3', '9', '4', '0', '4', '1', '4', '2', '4', '3', '4', '4', + '4', '5', '4', '6', '4', '7', '4', '8', '4', '9', '5', '0', '5', '1', '5', + '2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5', '9', + '6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6', + '7', '6', '8', '6', '9', '7', '0', '7', '1', '7', '2', '7', '3', '7', '4', + '7', '5', '7', '6', '7', '7', '7', '8', '7', '9', '8', '0', '8', '1', '8', + '2', '8', '3', '8', '4', '8', '5', '8', '6', '8', '7', '8', '8', '8', '9', + '9', '0', '9', '1', '9', '2', '9', '3', '9', '4', '9', '5', '9', '6', '9', + '7', '9', '8', '9', '9'}; + +constexpr uint32_t Tens(uint32_t n) { return n == 1 ? 10 : (Tens(n - 1) * 10); } + +struct UnsignedFloatBase2; + +struct UnsignedFloatBase10 { + uint32_t mantissa; + // Decimal exponent's range is -45 to 38 + // inclusive, and can fit in a short if needed. 
+ int32_t exponent; +}; + +template +To BitCast(From&& from) { + static_assert(sizeof(From) == sizeof(To), "Bit cast doesn't change output size."); + To t; + std::memcpy(&t, &from, sizeof(To)); + return t; +} + +struct IEEE754 { + static constexpr uint32_t kFloatMantissaBits = 23; + static constexpr uint32_t kFloatBias = 127; + static constexpr uint32_t kFloatExponentBits = 8; + + static void Decode(float f, UnsignedFloatBase2* uf, bool* signbit); + static float Encode(UnsignedFloatBase2 const& uf, bool signbit); + + static float Infinity(bool sign) { + uint32_t f = + ((static_cast(sign)) + << (IEEE754::kFloatExponentBits + IEEE754::kFloatMantissaBits)) | + (0xffu << IEEE754::kFloatMantissaBits); + float result = BitCast(f); + return result; + } +}; + +struct UnsignedFloatBase2 { + uint32_t mantissa; + // Decimal exponent's range is -45 to 38 + // inclusive, and can fit in a short if needed. + uint32_t exponent; + + bool Infinite() const { + return exponent == ((1u << IEEE754::kFloatExponentBits) - 1u); + } + bool Zero() const { + return mantissa == 0 && exponent == 0; + } +}; + +inline void IEEE754::Decode(float f, UnsignedFloatBase2 *uf, bool *signbit) { + auto bits = BitCast(f); + // Decode bits into sign, mantissa, and exponent. + *signbit = std::signbit(f); + uf->mantissa = bits & ((1u << kFloatMantissaBits) - 1); + uf->exponent = (bits >> IEEE754::kFloatMantissaBits) & + ((1u << IEEE754::kFloatExponentBits) - 1); // remove signbit +} + +inline float IEEE754::Encode(UnsignedFloatBase2 const &uf, bool signbit) { + uint32_t f = + ((((static_cast(signbit)) << IEEE754::kFloatExponentBits) | + static_cast(uf.exponent)) + << IEEE754::kFloatMantissaBits) | + uf.mantissa; + return BitCast(f); +} + +// Represents the interval of information-preserving outputs. 
+struct MantissaInteval { + int32_t exponent; + // low: smaller half way point + uint32_t mantissa_low; + // correct: f + uint32_t mantissa_correct; + // high: larger half way point + uint32_t mantissa_high; +}; + +struct RyuPowLogUtils { + // This table is generated by PrintFloatLookupTable from ryu. We adopted only the float + // 32 table instead of double full table. + // f2s_full_table.h + uint32_t constexpr static kFloatPow5InvBitcount = 59; + static constexpr uint64_t kFloatPow5InvSplit[55] = { + 576460752303423489u, 461168601842738791u, 368934881474191033u, + 295147905179352826u, 472236648286964522u, 377789318629571618u, + 302231454903657294u, 483570327845851670u, 386856262276681336u, + 309485009821345069u, 495176015714152110u, 396140812571321688u, + 316912650057057351u, 507060240091291761u, 405648192073033409u, + 324518553658426727u, 519229685853482763u, 415383748682786211u, + 332306998946228969u, 531691198313966350u, 425352958651173080u, + 340282366920938464u, 544451787073501542u, 435561429658801234u, + 348449143727040987u, 557518629963265579u, 446014903970612463u, + 356811923176489971u, 570899077082383953u, 456719261665907162u, + 365375409332725730u, 292300327466180584u, 467680523945888934u, + 374144419156711148u, 299315535325368918u, 478904856520590269u, + 383123885216472215u, 306499108173177772u, 490398573077084435u, + 392318858461667548u, 313855086769334039u, 502168138830934462u, + 401734511064747569u, 321387608851798056u, 514220174162876889u, + 411376139330301511u, 329100911464241209u, 526561458342785934u, + 421249166674228747u, 336999333339382998u, 539198933343012796u, + 431359146674410237u, 345087317339528190u, 552139707743245103u, + 441711766194596083u}; + + uint32_t constexpr static kFloatPow5Bitcount = 61; + static constexpr uint64_t kFloatPow5Split[47] = { + 1152921504606846976u, 1441151880758558720u, 1801439850948198400u, + 2251799813685248000u, 1407374883553280000u, 1759218604441600000u, + 2199023255552000000u, 1374389534720000000u, 
1717986918400000000u, + 2147483648000000000u, 1342177280000000000u, 1677721600000000000u, + 2097152000000000000u, 1310720000000000000u, 1638400000000000000u, + 2048000000000000000u, 1280000000000000000u, 1600000000000000000u, + 2000000000000000000u, 1250000000000000000u, 1562500000000000000u, + 1953125000000000000u, 1220703125000000000u, 1525878906250000000u, + 1907348632812500000u, 1192092895507812500u, 1490116119384765625u, + 1862645149230957031u, 1164153218269348144u, 1455191522836685180u, + 1818989403545856475u, 2273736754432320594u, 1421085471520200371u, + 1776356839400250464u, 2220446049250313080u, 1387778780781445675u, + 1734723475976807094u, 2168404344971008868u, 1355252715606880542u, + 1694065894508600678u, 2117582368135750847u, 1323488980084844279u, + 1654361225106055349u, 2067951531382569187u, 1292469707114105741u, + 1615587133892632177u, 2019483917365790221u}; + + static uint32_t Pow5Factor(uint32_t value) noexcept(true) { + uint32_t count = 0; + for (;;) { + const uint32_t q = value / 5; + const uint32_t r = value % 5; + if (r != 0) { + break; + } + value = q; + ++count; + } + return count; + } + + // Returns true if value is divisible by 5^p. + static bool MultipleOfPowerOf5(const uint32_t value, const uint32_t p) noexcept(true) { + return Pow5Factor(value) >= p; + } + + // Returns true if value is divisible by 2^p. + static bool MultipleOfPowerOf2(const uint32_t value, const uint32_t p) noexcept(true) { +#ifdef __GNUC__ + return static_cast(__builtin_ctz(value)) >= p; +#else + return (value & ((1u << p) - 1)) == 0; +#endif // __GNUC__ + } + + // Returns e == 0 ? 1 : ceil(log_2(5^e)). + static uint32_t Pow5Bits(const int32_t e) noexcept(true) { + return static_cast(((e * 163391164108059ull) >> 46) + 1); + } + + static int32_t Log2Pow5(const int32_t e) { + // This approximation works up to the point that the multiplication + // overflows at e = 3529. 
If the multiplication were done in 64 bits, it + // would fail at 5^4004 which is just greater than 2^9297. + assert(e >= 0); + assert(e <= 3528); + return static_cast(((static_cast(e)) * 1217359) >> 19); + } + + static int32_t CeilLog2Pow5(const int32_t e) { + return RyuPowLogUtils::Log2Pow5(e) + 1; + } + + /* + * \brief Multiply 32-bit and 64-bit -> 128 bit, then access the higher bits. + */ + static uint32_t MulShift(const uint32_t x, const uint64_t y, + const int32_t shift) noexcept(true) { + // For 32-bit * 64-bit: x * y, it can be decomposed into: + // + // x * (y_high + y_low) = (x * y_high) + (x * y_low) + // + // For more general case 64-bit * 64-bit, see https://stackoverflow.com/a/1541458 + const uint32_t y_low = static_cast(y); + const uint32_t y_high = static_cast(y >> 32); + + const uint64_t low = static_cast(x) * y_low; + const uint64_t high = static_cast(x) * y_high; + + const uint64_t sum = (low >> 32) + high; + const uint64_t shifted_sum = sum >> (shift - 32); + + return static_cast(shifted_sum); + } + + /* + * \brief floor(5^q/2*k) and shift by j + */ + static uint32_t MulPow5InvDivPow2(const uint32_t m, const uint32_t q, + const int32_t j) noexcept(true) { + return MulShift(m, kFloatPow5InvSplit[q], j); + } + + /* + * \brief floor(2^k/5^q) + 1 and shift by j + */ + static uint32_t MulPow5divPow2(const uint32_t m, const uint32_t i, + const int32_t j) noexcept(true) { + // clang-tidy makes false assumption that can lead to i >= 47, which is impossible. + // Can be verified by enumerating all float32 values. + return MulShift(m, kFloatPow5Split[i], j); // NOLINT + } + + static uint32_t FloorLog2(const uint64_t value) { + return 63 - __builtin_clzll(value); + } + + /* + * \brief floor(e * log_10(2)). + */ + static uint32_t Log10Pow2(const int32_t e) noexcept(true) { + // The first value this approximation fails for is 2^1651 which is just + // greater than 10^297. 
+ assert(e >= 0); + assert(e <= 1 << 15); + return static_cast((static_cast(e) * 169464822037455ull) >> 49); + } + + // Returns floor(e * log_10(5)). + static uint32_t Log10Pow5(const int32_t expoent) noexcept(true) { + // The first value this approximation fails for is 5^2621 which is just + // greater than 10^1832. + assert(expoent >= 0); + assert(expoent <= 1 << 15); + return static_cast( + ((static_cast(expoent)) * 196742565691928ull) >> 48); + } +}; + +constexpr uint64_t RyuPowLogUtils::kFloatPow5InvSplit[55]; +constexpr uint64_t RyuPowLogUtils::kFloatPow5Split[47]; + +class PowerBaseComputer { + private: + static uint8_t + ToDecimalBase(bool const accept_bounds, uint32_t const mantissa_low_shift, + MantissaInteval const base2, MantissaInteval *base10, + bool *mantissa_low_is_trailing_zeros, + bool *mantissa_out_is_trailing_zeros) noexcept(true) { + uint8_t last_removed_digit = 0; + if (base2.exponent >= 0) { + const uint32_t q = RyuPowLogUtils::Log10Pow2(base2.exponent); + base10->exponent = static_cast(q); + const int32_t k = RyuPowLogUtils::kFloatPow5InvBitcount + + RyuPowLogUtils::Pow5Bits(static_cast(q)) - 1; + const int32_t i = -base2.exponent + static_cast(q) + k; + base10->mantissa_low = + RyuPowLogUtils::MulPow5InvDivPow2(base2.mantissa_low, q, i); + base10->mantissa_correct = + RyuPowLogUtils::MulPow5InvDivPow2(base2.mantissa_correct, q, i); + base10->mantissa_high = + RyuPowLogUtils::MulPow5InvDivPow2(base2.mantissa_high, q, i); + + if (q != 0 && + (base10->mantissa_high - 1) / 10 <= base10->mantissa_low / 10) { + // We need to know one removed digit even if we are not going to loop + // below. We could use q = X - 1 above, except that would require 33 + // bits for the result, and we've found that 32-bit arithmetic is + // faster even on 64-bit machines. 
+ const int32_t l = + RyuPowLogUtils::kFloatPow5InvBitcount + + RyuPowLogUtils::Pow5Bits(static_cast(q - 1)) - 1; + last_removed_digit = static_cast( + RyuPowLogUtils::MulPow5InvDivPow2( + base2.mantissa_correct, q - 1, + -base2.exponent + static_cast(q) - 1 + l) % + 10); + } + if (q <= 9) { + // The largest power of 5 that fits in 24 bits is 5^10, but q <= 9 seems to be + // safe as well. Only one of mantissa_high, mantissa_correct, and mantissa_low can + // be a multiple of 5, if any. + if (base2.mantissa_correct % 5 == 0) { + *mantissa_out_is_trailing_zeros = + RyuPowLogUtils::MultipleOfPowerOf5(base2.mantissa_correct, q); + } else if (accept_bounds) { + *mantissa_low_is_trailing_zeros = + RyuPowLogUtils::MultipleOfPowerOf5(base2.mantissa_low, q); + } else { + base10->mantissa_high -= + RyuPowLogUtils::MultipleOfPowerOf5(base2.mantissa_high, q); + } + } + } else { + const uint32_t q = RyuPowLogUtils::Log10Pow5(-base2.exponent); + base10->exponent = static_cast(q) + base2.exponent; + const int32_t i = -base2.exponent - static_cast(q); + const int32_t k = + RyuPowLogUtils::Pow5Bits(i) - RyuPowLogUtils::kFloatPow5Bitcount; + int32_t j = static_cast(q) - k; + base10->mantissa_correct = RyuPowLogUtils::MulPow5divPow2( + base2.mantissa_correct, static_cast(i), j); + base10->mantissa_high = RyuPowLogUtils::MulPow5divPow2( + base2.mantissa_high, static_cast(i), j); + base10->mantissa_low = RyuPowLogUtils::MulPow5divPow2( + base2.mantissa_low, static_cast(i), j); + + if (q != 0 && + (base10->mantissa_high - 1) / 10 <= base10->mantissa_low / 10) { + j = static_cast(q) - 1 - + (RyuPowLogUtils::Pow5Bits(i + 1) - + RyuPowLogUtils::kFloatPow5Bitcount); + last_removed_digit = static_cast( + RyuPowLogUtils::MulPow5divPow2(base2.mantissa_correct, + static_cast(i + 1), j) % + 10); + } + if (q <= 1) { + // {mantissa_out, mantissa_out_high, mantissa_out_low} is trailing zeros if + // {mantissa_correct,mantissa_high,mantissa_low} has at least q trailing 0 + // bits.mantissa_correct 
= 4 * m2, so it always has at least two trailing 0 bits. + *mantissa_out_is_trailing_zeros = true; + if (accept_bounds) { + // mantissa_low = mantissa_correct - 1 - mantissa_low_shift, so it has 1 + // trailing 0 bit iff mmShift == 1. + *mantissa_low_is_trailing_zeros = mantissa_low_shift == 1; + } else { + // mantissa_high = mantissa_correct + 2, so it always has at least one trailing + // 0 bit. + --base10->mantissa_high; + } + } else if (q < 31) { + *mantissa_out_is_trailing_zeros = + RyuPowLogUtils::MultipleOfPowerOf2(base2.mantissa_correct, q - 1); + } + } + return last_removed_digit; + } + + /* + * \brief A varient of extended euclidean GCD algorithm. + */ + static UnsignedFloatBase10 + ShortestRepresentation(bool mantissa_low_is_trailing_zeros, + bool mantissa_out_is_trailing_zeros, + uint8_t last_removed_digit, bool const accept_bounds, + MantissaInteval base10) noexcept(true) { + int32_t removed {0}; + uint32_t output {0}; + + if (mantissa_low_is_trailing_zeros || mantissa_out_is_trailing_zeros) { + // General case, which happens rarely (~4.0%). + while (base10.mantissa_high / 10 > base10.mantissa_low / 10) { + mantissa_low_is_trailing_zeros &= base10.mantissa_low % 10 == 0; + mantissa_out_is_trailing_zeros &= last_removed_digit == 0; + last_removed_digit = static_cast(base10.mantissa_correct % 10); + base10.mantissa_correct /= 10; + base10.mantissa_high /= 10; + base10.mantissa_low /= 10; + ++removed; + } + + if (mantissa_low_is_trailing_zeros) { + while (base10.mantissa_low % 10 == 0) { + mantissa_out_is_trailing_zeros &= last_removed_digit == 0; + last_removed_digit = static_cast(base10.mantissa_correct % 10); + base10.mantissa_correct /= 10; + base10.mantissa_high /= 10; + base10.mantissa_low /= 10; + ++removed; + } + } + + if (mantissa_out_is_trailing_zeros && last_removed_digit == 5 && + base10.mantissa_correct % 2 == 0) { + // Round even if the exact number is .....50..0. 
+ last_removed_digit = 4; + } + // We need to take mantissa_out + 1 if mantissa_out is outside bounds or we need to + // round up. + output = base10.mantissa_correct + + ((base10.mantissa_correct == base10.mantissa_low && + (!accept_bounds || !mantissa_low_is_trailing_zeros)) || + last_removed_digit >= 5); + } else { + // Specialized for the common case (~96.0%). Percentages below are + // relative to this. Loop iterations below (approximately): 0: 13.6%, + // 1: 70.7%, 2: 14.1%, 3: 1.39%, 4: 0.14%, 5+: 0.01% + while (base10.mantissa_high / 10 > base10.mantissa_low / 10) { + last_removed_digit = static_cast(base10.mantissa_correct % 10); + base10.mantissa_correct /= 10; + base10.mantissa_high /= 10; + base10.mantissa_low /= 10; + ++removed; + } + + // We need to take mantissa_out + 1 if mantissa_out is outside bounds or we need to + // round up. + output = base10.mantissa_correct + + (base10.mantissa_correct == base10.mantissa_low || + last_removed_digit >= 5); + } + const int32_t exp = base10.exponent + removed; + + UnsignedFloatBase10 fd; + fd.exponent = exp; + fd.mantissa = output; + return fd; + } + + public: + static UnsignedFloatBase10 Binary2Decimal(UnsignedFloatBase2 const f) noexcept(true) { + MantissaInteval base2_range; + uint32_t mantissa_base2; + if (f.exponent == 0) { + // We subtract 2 so that the bounds computation has 2 additional bits. 
+ base2_range.exponent = static_cast(1) - + static_cast(IEEE754::kFloatBias) - + static_cast(IEEE754::kFloatMantissaBits) - + static_cast(2); + static_assert(static_cast(1) - + static_cast(IEEE754::kFloatBias) - + static_cast(IEEE754::kFloatMantissaBits) - + static_cast(2) == + -151, + ""); + mantissa_base2 = f.mantissa; + } else { + base2_range.exponent = static_cast(f.exponent) - IEEE754::kFloatBias - + IEEE754::kFloatMantissaBits - 2; + mantissa_base2 = (1u << IEEE754::kFloatMantissaBits) | f.mantissa; + } + const bool even = (mantissa_base2 & 1) == 0; + const bool accept_bounds = even; + + // Step 2: Determine the interval of valid decimal representations. + base2_range.mantissa_correct = 4 * mantissa_base2; + base2_range.mantissa_high = 4 * mantissa_base2 + 2; + // Implicit bool -> int conversion. True is 1, false is 0. + const uint32_t mantissa_low_shift = f.mantissa != 0 || f.exponent <= 1; + base2_range.mantissa_low = 4 * mantissa_base2 - 1 - mantissa_low_shift; + + // Step 3: Convert to a decimal power base using 64-bit arithmetic. + MantissaInteval base10_range; + bool mantissa_low_is_trailing_zeros = false; + bool mantissa_out_is_trailing_zeros = false; + auto last_removed_digit = PowerBaseComputer::ToDecimalBase( + accept_bounds, mantissa_low_shift, base2_range, &base10_range, + &mantissa_low_is_trailing_zeros, &mantissa_out_is_trailing_zeros); + + // Step 4: Find the shortest decimal representation in the interval of valid + // representations. + auto out = ShortestRepresentation(mantissa_low_is_trailing_zeros, + mantissa_out_is_trailing_zeros, + last_removed_digit, + accept_bounds, base10_range); + return out; + } +}; + +/* + * \brief Print the floating point number in base 10. + */ +class RyuPrinter { + private: + static inline uint32_t OutputLength(const uint32_t v) noexcept(true) { + // Function precondition: v is not a 10-digit number. + // (f2s: 9 digits are sufficient for round-tripping.) + // (d2fixed: We print 9-digit blocks.) 
+ static_assert(100000000 == Tens(8), ""); + assert(v < Tens(9)); + if (v >= Tens(8)) { + return 9; + } + if (v >= Tens(7)) { + return 8; + } + if (v >= Tens(6)) { + return 7; + } + if (v >= Tens(5)) { + return 6; + } + if (v >= Tens(4)) { + return 5; + } + if (v >= Tens(3)) { + return 4; + } + if (v >= Tens(2)) { + return 3; + } + if (v >= Tens(1)) { + return 2; + } + return 1; + } + + public: + static int32_t PrintBase10Float(UnsignedFloatBase10 v, const bool sign, + char *const result) noexcept(true) { + // Step 5: Print the decimal representation. + int index = 0; + if (sign) { + result[index++] = '-'; + } + + uint32_t output = v.mantissa; + const uint32_t out_length = OutputLength(output); + + // Print the decimal digits. + // The following code is equivalent to: + // for (uint32_t i = 0; i < olength - 1; ++i) { + // const uint32_t c = output % 10; output /= 10; + // result[index + olength - i] = (char) ('0' + c); + // } + // result[index] = '0' + output % 10; + uint32_t i = 0; + while (output >= Tens(4)) { + const uint32_t c = output % Tens(4); + output /= Tens(4); + const uint32_t c0 = (c % 100) << 1; + const uint32_t c1 = (c / 100) << 1; + // This is used to speed up decimal digit generation by copying + // pairs of digits into the final output. + std::memcpy(result + index + out_length - i - 1, kItoaLut + c0, 2); + std::memcpy(result + index + out_length - i - 3, kItoaLut + c1, 2); + i += 4; + } + if (output >= 100) { + const uint32_t c = (output % 100) << 1; + output /= 100; + std::memcpy(result + index + out_length - i - 1, kItoaLut + c, 2); + i += 2; + } + if (output >= 10) { + const uint32_t c = output << 1; + // We can't use std::memcpy here: the decimal dot goes between these two + // digits. + result[index + out_length - i] = kItoaLut[c + 1]; + result[index] = kItoaLut[c]; + } else { + result[index] = static_cast('0' + output); + } + + // Print decimal point if needed. 
+ if (out_length > 1) { + result[index + 1] = '.'; + index += out_length + 1; + } else { + ++index; + } + + // Print the exponent. + result[index++] = 'E'; + int32_t exp = v.exponent + static_cast(out_length) - 1; + if (exp < 0) { + result[index++] = '-'; + exp = -exp; + } + + if (exp >= 10) { + std::memcpy(result + index, kItoaLut + 2 * exp, 2); + index += 2; + } else { + result[index++] = static_cast('0' + exp); + } + + return index; + } + + static int32_t PrintSpecialFloat(const bool sign, UnsignedFloatBase2 f, + char *const result) noexcept(true) { + if (f.mantissa) { + std::memcpy(result, u8"NaN", 3); + return 3; + } + if (sign) { + result[0] = '-'; + } + if (f.exponent) { + std::memcpy(result + sign, u8"Infinity", 8); + return sign + 8; + } + std::memcpy(result + sign, u8"0E0", 3); + return sign + 3; + } +}; + +int32_t ToCharsFloatImpl(float f, char * const result) { + // Step 1: Decode the floating-point number, and unify normalized and + // subnormal cases. + UnsignedFloatBase2 uf32; + bool sign; + IEEE754::Decode(f, &uf32, &sign); + + // Case distinction; exit early for the easy cases. + if (uf32.Infinite() || uf32.Zero()) { + return RyuPrinter::PrintSpecialFloat(sign, uf32, result); + } + + const UnsignedFloatBase10 v = PowerBaseComputer::Binary2Decimal(uf32); + const auto index = RyuPrinter::PrintBase10Float(v, sign, result); + return index; +} + + +// ====================== Integer ================== + +// This is an implementation for base 10 inspired by the one in libstdc++v3. The general +// scheme is by decomposing the value into multiple combination of base (which is 10) by +// mod, until the value is lesser than 10, then last char is just char '0' (ascii 48) plus +// that value. Other popular implementations can be found in RapidJson and libc++ (in +// llvm-project), which uses the same general work flow with the same look up table, but +// probably with better performance as they are more complicated. 
+void ItoaUnsignedImpl(char *first, uint32_t length, uint64_t value) { + uint32_t position = length - 1; + while (value >= Tens(2)) { + auto const num = (value % Tens(2)) * 2; + value /= Tens(2); + first[position] = kItoaLut[num + 1]; + first[position - 1] = kItoaLut[num]; + position -= 2; + } + if (value >= 10) { + auto const num = value * 2; + first[0] = kItoaLut[num]; + first[1] = kItoaLut[num + 1]; + } else { + first[0]= '0' + value; + } +} + +constexpr uint32_t ShortestDigit10Impl(uint64_t value, uint32_t n) { + // Should trigger tail recursion optimization. + return value < 10 ? n : + (value < Tens(2) ? n + 1 : + (value < Tens(3) ? n + 2 : + (value < Tens(4) ? n + 3 : + ShortestDigit10Impl(value / Tens(4), n + 4)))); +} + +constexpr uint32_t ShortestDigit10(uint64_t value) { + return ShortestDigit10Impl(value, 1); +} + +to_chars_result ToCharsUnsignedImpl(char *first, char *last, + uint64_t const value) { + const uint32_t output_len = ShortestDigit10(value); + to_chars_result ret; + if (XGBOOST_EXPECT(std::distance(first, last) == 0, false)) { + ret.ec = std::errc::value_too_large; + ret.ptr = last; + return ret; + } + + ItoaUnsignedImpl(first, output_len, value); + ret.ptr = first + output_len; + ret.ec = std::errc(); + return ret; +} + +/* + * The parsing is also part of ryu. As of writing, the implementation in ryu uses full + * double table. But here we optimize the table size with float table instead. The + * result is exactly the same. 
+ */ +from_chars_result FromCharFloatImpl(const char *buffer, const int len, + float *result) { + if (len == 0) { + return {buffer, std::errc::invalid_argument}; + } + int32_t m10digits = 0; + int32_t e10digits = 0; + int32_t dot_ind = len; + int32_t e_ind = len; + uint32_t mantissa_b10 = 0; + int32_t exp_b10 = 0; + bool signed_mantissa = false; + bool signed_exp = false; + int32_t i = 0; + if (buffer[i] == '-') { + signed_mantissa = true; + i++; + } + for (; i < len; i++) { + char c = buffer[i]; + if (c == '.') { + if (dot_ind != len) { + return {buffer + i, std::errc::invalid_argument}; + } + dot_ind = i; + continue; + } + if ((c < '0') || (c > '9')) { + break; + } + if (m10digits >= 9) { + return {buffer + i, std::errc::result_out_of_range}; + } + mantissa_b10 = 10 * mantissa_b10 + (c - '0'); + if (mantissa_b10 != 0) { + m10digits++; + } + } + + if (i < len && ((buffer[i] == 'e') || (buffer[i] == 'E'))) { + e_ind = i; + i++; + if (i < len && ((buffer[i] == '-') || (buffer[i] == '+'))) { + signed_exp = buffer[i] == '-'; + i++; + } + for (; i < len; i++) { + char c = buffer[i]; + if ((c < '0') || (c > '9')) { + return {buffer + i, std::errc::invalid_argument}; + } + if (e10digits > 3) { + return {buffer + i, std::errc::result_out_of_range}; + } + exp_b10 = 10 * exp_b10 + (c - '0'); + if (exp_b10 != 0) { + e10digits++; + } + } + } + if (i < len) { + return {buffer + i, std::errc::invalid_argument}; + } + if (signed_exp) { + exp_b10 = -exp_b10; + } + exp_b10 -= dot_ind < e_ind ? e_ind - dot_ind - 1 : 0; + if (mantissa_b10 == 0) { + *result = signed_mantissa ? -0.0f : 0.0f; + return {}; + } + + if ((m10digits + exp_b10 <= -46) || (mantissa_b10 == 0)) { + // Number is less than 1e-46, which should be rounded down to 0; return + // +/-0.0. 
+ uint32_t ieee = + (static_cast(signed_mantissa)) + << (IEEE754::kFloatExponentBits + IEEE754::kFloatMantissaBits); + *result = BitCast(ieee); + return {}; + } + if (m10digits + exp_b10 >= 40) { + // Number is larger than 1e+39, which should be rounded to +/-Infinity. + *result = IEEE754::Infinity(signed_mantissa); + return {}; + } + + // Convert to binary float m2 * 2^e2, while retaining information about + // whether the conversion was exact (trailingZeros). + int32_t exp_b2; + uint32_t mantissa_b2; + bool trailing_zeros; + if (exp_b10 >= 0) { + // The length of m * 10^e in bits is: + // log2(m10 * 10^e10) = log2(m10) + e10 log2(10) = log2(m10) + e10 + e10 * + // log2(5) + // + // We want to compute the IEEE754::kFloatMantissaBits + 1 top-most bits (+1 for the + // implicit leading one in IEEE format). We therefore choose a binary output + // exponent of + // log2(m10 * 10^e10) - (IEEE754::kFloatMantissaBits + 1). + // + // We use floor(log2(5^e10)) so that we get at least this many bits; better + // to have an additional bit than to not have enough bits. + exp_b2 = RyuPowLogUtils::FloorLog2(mantissa_b10) + exp_b10 + + RyuPowLogUtils::Log2Pow5(exp_b10) - + (IEEE754::kFloatMantissaBits + 1); + + // We now compute [m10 * 10^e10 / 2^e2] = [m10 * 5^e10 / 2^(e2-e10)]. + // To that end, we use the RyuPowLogUtils::kFloatPow5Bitcount table. + int j = exp_b2 - exp_b10 - RyuPowLogUtils::CeilLog2Pow5(exp_b10) + + RyuPowLogUtils::kFloatPow5Bitcount; + assert(j >= 0); + mantissa_b2 = RyuPowLogUtils::MulPow5divPow2(mantissa_b10, exp_b10, j); + + // We also compute if the result is exact, i.e., + // [m10 * 10^e10 / 2^e2] == m10 * 10^e10 / 2^e2. + // This can only be the case if 2^e2 divides m10 * 10^e10, which in turn + // requires that the largest power of 2 that divides m10 + e10 is greater + // than e2. If e2 is less than e10, then the result must be exact. Otherwise + // we use the existing multipleOfPowerOf2 function. 
+ trailing_zeros = + exp_b2 < exp_b10 || + (exp_b2 - exp_b10 < 32 && + RyuPowLogUtils::MultipleOfPowerOf2(mantissa_b10, exp_b2 - exp_b10)); + } else { + exp_b2 = RyuPowLogUtils::FloorLog2(mantissa_b10) + exp_b10 - + RyuPowLogUtils::CeilLog2Pow5(-exp_b10) - + (IEEE754::kFloatMantissaBits + 1); + + // We now compute [m10 * 10^e10 / 2^e2] = [m10 / (5^(-e10) 2^(e2-e10))]. + int j = exp_b2 - exp_b10 + RyuPowLogUtils::CeilLog2Pow5(-exp_b10) - 1 + + RyuPowLogUtils::kFloatPow5InvBitcount; + mantissa_b2 = RyuPowLogUtils::MulPow5InvDivPow2(mantissa_b10, -exp_b10, j); + + // We also compute if the result is exact, i.e., + // [m10 / (5^(-e10) 2^(e2-e10))] == m10 / (5^(-e10) 2^(e2-e10)) + // + // If e2-e10 >= 0, we need to check whether (5^(-e10) 2^(e2-e10)) divides + // m10, which is the case iff pow5(m10) >= -e10 AND pow2(m10) >= e2-e10. + // + // If e2-e10 < 0, we have actually computed [m10 * 2^(e10 e2) / 5^(-e10)] + // above, and we need to check whether 5^(-e10) divides (m10 * 2^(e10-e2)), + // which is the case iff pow5(m10 * 2^(e10-e2)) = pow5(m10) >= -e10. + trailing_zeros = + (exp_b2 < exp_b10 || + (exp_b2 - exp_b10 < 32 && RyuPowLogUtils::MultipleOfPowerOf2( + mantissa_b10, exp_b2 - exp_b10))) && + RyuPowLogUtils::MultipleOfPowerOf5(mantissa_b10, -exp_b10); + } + + // Compute the final IEEE exponent. + uint32_t f_e2 = + std::max(static_cast(0), + static_cast(exp_b2 + IEEE754::kFloatBias + + RyuPowLogUtils::FloorLog2(mantissa_b2))); + + if (f_e2 > 0xfe) { + // Final IEEE exponent is larger than the maximum representable; return + // +/-Infinity. + *result = IEEE754::Infinity(signed_mantissa); + return {}; + } + + // We need to figure out how much we need to shift m2. The tricky part is that + // we need to take the final IEEE exponent into account, so we need to reverse + // the bias and also special-case the value 0. + int32_t shift = (f_e2 == 0 ? 
1 : f_e2) - exp_b2 - IEEE754::kFloatBias - + IEEE754::kFloatMantissaBits; + assert(shift >= 0); + + // We need to round up if the exact value is more than 0.5 above the value we + // computed. That's equivalent to checking if the last removed bit was 1 and + // either the value was not just trailing zeros or the result would otherwise + // be odd. + // + // We need to update trailingZeros given that we have the exact output + // exponent ieee_e2 now. + trailing_zeros &= (mantissa_b2 & ((1u << (shift - 1)) - 1)) == 0; + uint32_t lastRemovedBit = (mantissa_b2 >> (shift - 1)) & 1; + bool roundup = (lastRemovedBit != 0) && + (!trailing_zeros || (((mantissa_b2 >> shift) & 1) != 0)); + + uint32_t f_m2 = (mantissa_b2 >> shift) + roundup; + assert(f_m2 <= (1u << (IEEE754::kFloatMantissaBits + 1))); + f_m2 &= (1u << IEEE754::kFloatMantissaBits) - 1; + if (f_m2 == 0 && roundup) { + // Rounding up may overflow the mantissa. + // In this case we move a trailing zero of the mantissa into the exponent. + // Due to how the IEEE represents +/-Infinity, we don't need to check for + // overflow here. + f_e2++; + } + *result = IEEE754::Encode({f_m2, f_e2}, signed_mantissa); + return {}; +} +} // namespace detail +} // namespace xgboost diff --git a/src/common/charconv.h b/src/common/charconv.h new file mode 100644 index 000000000..b931ed7ce --- /dev/null +++ b/src/common/charconv.h @@ -0,0 +1,103 @@ +/*! + * Copyright 2019 by XGBoost Contributors + * + * \brief Implement `std::to_chars` and `std::from_chars` for float. Only base 10 with + * scientific format is supported. The implementation guarantees roundtrip + * reproducibility. 
+ */ +#ifndef XGBOOST_COMMON_CHARCONV_H_ +#define XGBOOST_COMMON_CHARCONV_H_ + +#include +#include +#include +#include + +#include "xgboost/logging.h" + +namespace xgboost { + +struct to_chars_result { // NOLINT + char* ptr; + std::errc ec; +}; + +struct from_chars_result { // NOLINT + const char *ptr; + std::errc ec; +}; + +namespace detail { +int32_t ToCharsFloatImpl(float f, char * const result); +to_chars_result ToCharsUnsignedImpl(char *first, char *last, + uint64_t const value); +from_chars_result FromCharFloatImpl(const char *buffer, const int len, + float *result); +} // namespace detail + +template +struct NumericLimits; + +template <> struct NumericLimits { + // Unlike std::numeric_limit::max_digits10, which represents the **minimum** + // length of base10 digits that are necessary to uniquely represent all distinct values. + // This value is used to represent the maximum length. As sign bit occupies 1 character: + // sign + len(str(2^24)) + decimal point + `E` + sign + len(str(2^8)) + '\0' + static constexpr size_t kToCharsSize = 16; +}; + +template <> struct NumericLimits { + // From llvm libcxx: numeric_limits::digits10 returns value less on 1 than desired for + // unsigned numbers. For example, for 1-byte unsigned value digits10 is 2 (999 can not + // be represented), so we need +1 here. + static constexpr size_t kToCharsSize = + std::numeric_limits::digits10 + + 3; // +1 for minus, +1 for digits10, +1 for '\0' just to be safe. 
+}; + +inline to_chars_result to_chars(char *first, char *last, float value) { // NOLINT + if (XGBOOST_EXPECT(!(static_cast(last - first) >= + NumericLimits::kToCharsSize), + false)) { + return {first, std::errc::value_too_large}; + } + auto index = detail::ToCharsFloatImpl(value, first); + to_chars_result ret; + ret.ptr = first + index; + + if (XGBOOST_EXPECT(ret.ptr < last, true)) { + ret.ec = std::errc(); + } else { + ret.ec = std::errc::value_too_large; + ret.ptr = last; + } + return ret; +} + +inline to_chars_result to_chars(char *first, char *last, int64_t value) { // NOLINT + if (XGBOOST_EXPECT(first == last, false)) { + return {first, std::errc::value_too_large}; + } + // first write '-' and convert to unsigned, then write the rest. + if (value == 0) { + *first = '0'; + return {std::next(first), std::errc()}; + } + uint64_t unsigned_value = value; + if (value < 0) { + *first = '-'; + std::advance(first, 1); + unsigned_value = uint64_t(~value) + uint64_t(1); + } + return detail::ToCharsUnsignedImpl(first, last, unsigned_value); +} + +inline from_chars_result from_chars(const char *buffer, const char *end, // NOLINT + float &value) { // NOLINT + from_chars_result res = + detail::FromCharFloatImpl(buffer, std::distance(buffer, end), &value); + return res; +} +} // namespace xgboost + +#endif // XGBOOST_COMMON_CHARCONV_H_ diff --git a/src/common/json.cc b/src/common/json.cc index ca46623d4..6ba82aa91 100644 --- a/src/common/json.cc +++ b/src/common/json.cc @@ -1,12 +1,15 @@ /*! 
- * Copyright (c) by Contributors 2019 + * Copyright (c) by Contributors 2019-2020 */ #include +#include +#include #include #include #include #include +#include "charconv.h" #include "xgboost/base.h" #include "xgboost/logging.h" #include "xgboost/json.h" @@ -19,56 +22,68 @@ void JsonWriter::Save(Json json) { } void JsonWriter::Visit(JsonArray const* arr) { - this->Write("["); + stream_->emplace_back('['); auto const& vec = arr->GetArray(); size_t size = vec.size(); for (size_t i = 0; i < size; ++i) { auto const& value = vec[i]; this->Save(value); - if (i != size-1) { Write(","); } + if (i != size - 1) { + stream_->emplace_back(','); + } } - this->Write("]"); + stream_->emplace_back(']'); } void JsonWriter::Visit(JsonObject const* obj) { - this->Write("{"); - this->BeginIndent(); - this->NewLine(); - + stream_->emplace_back('{'); size_t i = 0; size_t size = obj->GetObject().size(); for (auto& value : obj->GetObject()) { - this->Write("\"" + value.first + "\":"); + auto s = String{value.first}; + this->Visit(&s); + stream_->emplace_back(':'); this->Save(value.second); if (i != size-1) { - this->Write(","); - this->NewLine(); + stream_->emplace_back(','); } i++; } - this->EndIndent(); - this->NewLine(); - this->Write("}"); + + stream_->emplace_back('}'); } void JsonWriter::Visit(JsonNumber const* num) { - convertor_ << num->GetNumber(); - auto const& str = convertor_.str(); - this->Write(StringView{str.c_str(), str.size()}); - convertor_.str(""); + char number[NumericLimits::kToCharsSize]; + auto res = to_chars(number, number + sizeof(number), num->GetNumber()); + auto end = res.ptr; + auto ori_size = stream_->size(); + stream_->resize(stream_->size() + end - number); + std::memcpy(stream_->data() + ori_size, number, end - number); } void JsonWriter::Visit(JsonInteger const* num) { - convertor_ << num->GetInteger(); - auto const& str = convertor_.str(); - this->Write(StringView{str.c_str(), str.size()}); - convertor_.str(""); + char 
i2s_buffer_[NumericLimits::kToCharsSize]; + auto i = num->GetInteger(); + auto ret = to_chars(i2s_buffer_, i2s_buffer_ + NumericLimits::kToCharsSize, i); + auto end = ret.ptr; + CHECK(ret.ec == std::errc()); + auto digits = std::distance(i2s_buffer_, end); + auto ori_size = stream_->size(); + stream_->resize(ori_size + digits); + std::memcpy(stream_->data() + ori_size, i2s_buffer_, digits); } void JsonWriter::Visit(JsonNull const* null) { - this->Write("null"); + auto s = stream_->size(); + stream_->resize(s + 4); + auto& buf = (*stream_); + buf[s + 0] = 'n'; + buf[s + 1] = 'u'; + buf[s + 2] = 'l'; + buf[s + 3] = 'l'; } void JsonWriter::Visit(JsonString const* str) { @@ -105,15 +120,30 @@ void JsonWriter::Visit(JsonString const* str) { } } buffer += '"'; - this->Write(buffer); + + auto s = stream_->size(); + stream_->resize(s + buffer.size()); + std::memcpy(stream_->data() + s, buffer.data(), buffer.size()); } void JsonWriter::Visit(JsonBoolean const* boolean) { bool val = boolean->GetBoolean(); + auto s = stream_->size(); if (val) { - this->Write(u8"true"); + stream_->resize(s + 4); + auto& buf = (*stream_); + buf[s + 0] = 't'; + buf[s + 1] = 'r'; + buf[s + 2] = 'u'; + buf[s + 3] = 'e'; } else { - this->Write(u8"false"); + stream_->resize(s + 5); + auto& buf = (*stream_); + buf[s + 0] = 'f'; + buf[s + 1] = 'a'; + buf[s + 2] = 'l'; + buf[s + 3] = 's'; + buf[s + 4] = 'e'; } } @@ -310,7 +340,7 @@ Value & JsonNull::operator=(Value const &rhs) { } void JsonNull::Save(JsonWriter* writer) { - writer->Write("null"); + writer->Visit(this); } // Json Boolean @@ -354,7 +384,7 @@ Json JsonReader::Parse() { } else if ( c == '[' ) { return ParseArray(); } else if ( c == '-' || std::isdigit(c) || - c == 'N' ) { + c == 'N' || c == 'I') { // For now we only accept `NaN`, not `nan` as the later violiates LR(1) with `null`. return ParseNumber(); } else if ( c == '\"' ) { @@ -379,9 +409,13 @@ void JsonReader::Error(std::string msg) const { // just copy it. 
std::istringstream str_s(raw_str_.substr(0, raw_str_.size())); - msg += ", around character: " + std::to_string(cursor_.Pos()); + msg += ", around character position: " + std::to_string(cursor_.Pos()); msg += '\n'; + if (cursor_.Pos() == 0) { + LOG(FATAL) << msg << ", \"" << str_s.str() << " \""; + } + constexpr size_t kExtend = 8; auto beg = static_cast(cursor_.Pos()) - static_cast(kExtend) < 0 ? 0 : cursor_.Pos() - kExtend; @@ -413,11 +447,15 @@ void JsonReader::Error(std::string msg) const { LOG(FATAL) << msg; } +namespace { +bool IsSpace(char c) { return c == ' ' || c == '\n' || c == '\r' || c == '\t'; } +} // anonymous namespace + // Json class void JsonReader::SkipSpaces() { while (cursor_.Pos() < raw_str_.size()) { char c = raw_str_[cursor_.Pos()]; - if (std::isspace(c)) { + if (IsSpace(c)) { cursor_.Forward(); } else { break; @@ -438,7 +476,7 @@ void ParseStr(std::string const& str) { } Json JsonReader::ParseString() { - char ch { GetChar('\"') }; // NOLINT + char ch { GetConsecutiveChar('\"') }; // NOLINT std::ostringstream output; std::string str; while (true) { @@ -483,14 +521,14 @@ Json JsonReader::ParseNull() { Json JsonReader::ParseArray() { std::vector data; - char ch { GetChar('[') }; // NOLINT + char ch { GetConsecutiveChar('[') }; // NOLINT while (true) { if (PeekNextChar() == ']') { - GetChar(']'); + GetConsecutiveChar(']'); return Json(std::move(data)); } auto obj = Parse(); - data.push_back(obj); + data.emplace_back(obj); ch = GetNextNonSpaceChar(); if (ch == ']') break; if (ch != ',') { @@ -502,14 +540,14 @@ Json JsonReader::ParseArray() { } Json JsonReader::ParseObject() { - GetChar('{'); + GetConsecutiveChar('{'); std::map data; SkipSpaces(); char ch = PeekNextChar(); if (ch == '}') { - GetChar('}'); + GetConsecutiveChar('}'); return Json(std::move(data)); } @@ -550,118 +588,105 @@ Json JsonReader::ParseNumber() { char const* const beg = p; // keep track of current pointer // TODO(trivialfis): Add back all the checks for number - bool 
negative = false; if (XGBOOST_EXPECT(*p == 'N', false)) { - GetChar('N'); - GetChar('a'); - GetChar('N'); + GetConsecutiveChar('N'); + GetConsecutiveChar('a'); + GetConsecutiveChar('N'); return Json(static_cast(std::numeric_limits::quiet_NaN())); } - if ('-' == *p) { - ++p; + bool negative = false; + switch (*p) { + case '-': { negative = true; + ++p; + break; + } + case '+': { + negative = false; + ++p; + break; + } + default: { + break; + } + } + + if (XGBOOST_EXPECT(*p == 'I', false)) { + cursor_.Forward(std::distance(beg, p)); // +/- + for (auto i : {'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'}) { + GetConsecutiveChar(i); + } + auto f = std::numeric_limits::infinity(); + if (negative) { + f = -f; + } + return Json(static_cast(f)); } bool is_float = false; - using ExpInt = std::remove_const< - decltype(std::numeric_limits::max_exponent)>::type; - constexpr auto kExpMax = std::numeric_limits::max(); - constexpr auto kExpMin = std::numeric_limits::min(); - - JsonInteger::Int i = 0; - double f = 0.0; // Use double to maintain accuracy + int64_t i = 0; if (*p == '0') { - ++p; - } else { - char c = *p; - do { - ++p; - char digit = c - '0'; - i = 10 * i + digit; - c = *p; - } while (std::isdigit(c)); + i = 0; + p++; } - ExpInt exponent = 0; - const char *const dot_position = p; - if ('.' 
== *p) { + while (XGBOOST_EXPECT(*p >= '0' && *p <= '9', true)) { + i = i * 10 + (*p - '0'); + p++; + } + + if (*p == '.') { + p++; is_float = true; - f = i; - ++p; - char c = *p; - do { - ++p; - f = f * 10 + (c - '0'); - c = *p; - } while (std::isdigit(c)); - } - if (is_float) { - exponent = dot_position - p + 1; + while (*p >= '0' && *p <= '9') { + i = i * 10 + (*p - '0'); + p++; + } } - char e = *p; - if ('e' == e || 'E' == e) { - if (!is_float) { - is_float = true; - f = i; - } - ++p; + if (*p == 'E' || *p == 'e') { + is_float = true; + p++; - bool negative_exponent = false; - if ('-' == *p) { - negative_exponent = true; - ++p; - } else if ('+' == *p) { - ++p; + switch (*p) { + case '-': + case '+': { + p++; + break; + } + default: + break; } - ExpInt exp = 0; - - char c = *p; - while (std::isdigit(c)) { - unsigned char digit = c - '0'; - if (XGBOOST_EXPECT(exp > (kExpMax - digit) / 10, false)) { - CHECK_GT(exp, (kExpMax - digit) / 10) << "Overflow"; + if (XGBOOST_EXPECT(*p >= '0' && *p <= '9', true)) { + p++; + while (*p >= '0' && *p <= '9') { + p++; } - exp = 10 * exp + digit; - ++p; - c = *p; + } else { + Error("Expecting digit"); } - static_assert(-kExpMax >= kExpMin, "exp can be negated without loss or UB"); - exponent += (negative_exponent ? -exp : exp); - } - - if (exponent) { - CHECK(is_float); - // If d is zero but the exponent is huge, don't - // multiply zero by inf which gives nan. 
- if (f != 0.0) { - // Only use exp10 from libc on gcc+linux -#if !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) || !defined(__linux__) -#define exp10(val) std::pow(10, (val)) -#endif // !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) || !defined(__linux__) - f *= exp10(exponent); -#if !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) || !defined(__linux__) -#undef exp10 -#endif // !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) || !defined(__linux__) - } - } - - if (negative) { - f = -f; - i = -i; } auto moved = std::distance(beg, p); this->cursor_.Forward(moved); if (is_float) { + float f; + auto ret = from_chars(beg, p, f); + if (XGBOOST_EXPECT(ret.ec != std::errc(), false)) { + // Compatible with old format that generates very long mantissa from std stream. + f = std::strtof(beg, nullptr); + } return Json(static_cast(f)); } else { + if (negative) { + i = -i; + } return Json(JsonInteger(i)); } } @@ -674,20 +699,15 @@ Json JsonReader::ParseBoolean() { std::string buffer; if (ch == 't') { - for (size_t i = 0; i < 3; ++i) { - buffer.push_back(GetNextNonSpaceChar()); - } - if (buffer != u8"rue") { - Error("Expecting boolean value \"true\"."); - } + GetConsecutiveChar('r'); + GetConsecutiveChar('u'); + GetConsecutiveChar('e'); result = true; } else { - for (size_t i = 0; i < 4; ++i) { - buffer.push_back(GetNextNonSpaceChar()); - } - if (buffer != u8"alse") { - Error("Expecting boolean value \"false\"."); - } + GetConsecutiveChar('a'); + GetConsecutiveChar('l'); + GetConsecutiveChar('s'); + GetConsecutiveChar('e'); result = false; } return Json{JsonBoolean{result}}; @@ -704,16 +724,12 @@ Json Json::Load(JsonReader* reader) { return json; } -void Json::Dump(Json json, std::ostream *stream, bool pretty) { - JsonWriter writer(stream, pretty); +void Json::Dump(Json json, std::string* str) { + std::vector buffer; + JsonWriter writer(&buffer); writer.Save(json); -} - -void Json::Dump(Json json, std::string* str, bool pretty) { 
- std::stringstream ss; - JsonWriter writer(&ss, pretty); - writer.Save(json); - *str = ss.str(); + str->resize(buffer.size()); + std::copy(buffer.cbegin(), buffer.cend(), str->begin()); } Json& Json::operator=(Json const &other) = default; diff --git a/src/common/observer.h b/src/common/observer.h index e83c6c80b..1af16d45d 100644 --- a/src/common/observer.h +++ b/src/common/observer.h @@ -53,7 +53,7 @@ class TrainingObserver { Json j_tree {Object()}; tree.SaveModel(&j_tree); std::string str; - Json::Dump(j_tree, &str, true); + Json::Dump(j_tree, &str); OBSERVER_PRINT << str << OBSERVER_ENDL; } /*\brief Observe tree. */ diff --git a/src/common/timer.cc b/src/common/timer.cc index 49d08a35c..79c823dd7 100644 --- a/src/common/timer.cc +++ b/src/common/timer.cc @@ -61,9 +61,8 @@ std::vector Monitor::CollectFromOtherRanks() const { kv.second.timer.elapsed).count())); } - std::stringstream ss; - Json::Dump(j_statistic, &ss); - std::string const str { ss.str() }; + std::string str; + Json::Dump(j_statistic, &str); size_t str_size = str.size(); rabit::Allreduce(&str_size, 1); diff --git a/tests/cpp/common/test_charconv.cc b/tests/cpp/common/test_charconv.cc new file mode 100644 index 000000000..cce48f76f --- /dev/null +++ b/tests/cpp/common/test_charconv.cc @@ -0,0 +1,213 @@ +/* + * The code is adopted from original (half) c implementation: + * https://github.com/ulfjack/ryu.git with some more comments and tidying. License is + * attached below. + * + * Copyright 2018 Ulf Adams + * + * The contents of this file may be used under the terms of the Apache License, + * Version 2.0. + * + * (See accompanying file LICENSE-Apache or copy at + * http: *www.apache.org/licenses/LICENSE-2.0) + * + * Alternatively, the contents of this file may be used under the terms of + * the Boost Software License, Version 1.0. 
+ * (See accompanying file LICENSE-Boost or copy at + * https://www.boost.org/LICENSE_1_0.txt) + * + * Unless required by applicable law or agreed to in writing, this software + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. + */ +#include +#include +#include +#include "../../../src/common/charconv.h" + +namespace xgboost { +namespace { +void TestInteger(char const* res, int64_t i) { + char result[xgboost::NumericLimits::kToCharsSize]; + auto ret = to_chars(result, result + sizeof(result), i); + *ret.ptr = '\0'; + EXPECT_STREQ(res, result); +} + +static float Int32Bits2Float(uint32_t bits) { + float f; + memcpy(&f, &bits, sizeof(float)); + return f; +} + +void TestRyu(char const *res, float v) { + char result[xgboost::NumericLimits::kToCharsSize]; + auto ret = to_chars(result, result + sizeof(result), v); + *ret.ptr = '\0'; + EXPECT_STREQ(res, result); +} +} // anonymous namespace + +TEST(Ryu, Subnormal) { + TestRyu("0E0", 0.0f); + TestRyu("-0E0", -0.0f); + TestRyu("1E0", 1.0f); + TestRyu("-1E0", -1.0f); + TestRyu("NaN", NAN); + TestRyu("Infinity", INFINITY); + TestRyu("-Infinity", -INFINITY); + + TestRyu("1E-45", std::numeric_limits::denorm_min()); +} + +TEST(Ryu, Denormal) { + TestRyu("1E-45", std::numeric_limits::denorm_min()); +} + +TEST(Ryu, SwitchToSubnormal) { + TestRyu("1.1754944E-38", 1.1754944E-38f); +} + +TEST(Ryu, MinAndMax) { + TestRyu("3.4028235E38", Int32Bits2Float(0x7f7fffff)); + TestRyu("1E-45", Int32Bits2Float(1)); +} + +// Check that we return the exact boundary if it is the shortest +// representation, but only if the original floating point number is even. +TEST(Ryu, BoundaryRoundEven) { + TestRyu("3.355445E7", 3.355445E7f); + TestRyu("9E9", 8.999999E9f); + TestRyu("3.436672E10", 3.4366717E10f); +} + +// If the exact value is exactly halfway between two shortest representations, +// then we round to even. 
It seems like this only makes a difference if the +// last two digits are ...2|5 or ...7|5, and we cut off the 5. +TEST(Ryu, ExactValueRoundEven) { + TestRyu("3.0540412E5", 3.0540412E5f); + TestRyu("8.0990312E3", 8.0990312E3f); +} + +TEST(Ryu, LotsOfTrailingZeros) { + // Pattern for the first test: 00111001100000000000000000000000 + TestRyu("2.4414062E-4", 2.4414062E-4f); + TestRyu("2.4414062E-3", 2.4414062E-3f); + TestRyu("4.3945312E-3", 4.3945312E-3f); + TestRyu("6.3476562E-3", 6.3476562E-3f); +} + +TEST(Ryu, Regression) { + TestRyu("4.7223665E21", 4.7223665E21f); + TestRyu("8.388608E6", 8388608.0f); + TestRyu("1.6777216E7", 1.6777216E7f); + TestRyu("3.3554436E7", 3.3554436E7f); + TestRyu("6.7131496E7", 6.7131496E7f); + TestRyu("1.9310392E-38", 1.9310392E-38f); + TestRyu("-2.47E-43", -2.47E-43f); + TestRyu("1.993244E-38", 1.993244E-38f); + TestRyu("4.1039004E3", 4103.9003f); + TestRyu("5.3399997E9", 5.3399997E9f); + TestRyu("6.0898E-39", 6.0898E-39f); + TestRyu("1.0310042E-3", 0.0010310042f); + TestRyu("2.882326E17", 2.8823261E17f); + TestRyu("7.038531E-26", 7.0385309E-26f); + TestRyu("9.223404E17", 9.2234038E17f); + TestRyu("6.710887E7", 6.7108872E7f); + TestRyu("1E-44", 1.0E-44f); + TestRyu("2.816025E14", 2.816025E14f); + TestRyu("9.223372E18", 9.223372E18f); + TestRyu("1.5846086E29", 1.5846085E29f); + TestRyu("1.1811161E19", 1.1811161E19f); + TestRyu("5.368709E18", 5.368709E18f); + TestRyu("4.6143166E18", 4.6143165E18f); + TestRyu("7.812537E-3", 0.007812537f); + TestRyu("1E-45", 1.4E-45f); + TestRyu("1.18697725E20", 1.18697724E20f); + TestRyu("1.00014165E-36", 1.00014165E-36f); + TestRyu("2E2", 200.0f); + TestRyu("3.3554432E7", 3.3554432E7f); + + static_assert(1.1920929E-7f == std::numeric_limits::epsilon(), ""); + TestRyu("1.1920929E-7", std::numeric_limits::epsilon()); +} + +TEST(Ryu, RoundTrip) { + float f = -1.1493590134238582e-40; + char result[NumericLimits::kToCharsSize] { 0 }; + auto ret = to_chars(result, result + sizeof(result), f); + size_t dis = 
std::distance(result, ret.ptr); + float back; + auto from_ret = from_chars(result, result + dis, back); + ASSERT_EQ(from_ret.ec, std::errc()); + std::string str; + for (size_t i = 0; i < dis; ++i) { + str.push_back(result[i]); + } + ASSERT_EQ(f, back); +} + +TEST(Ryu, LooksLikePow5) { + // These numbers have a mantissa that is the largest power of 5 that fits, + // and an exponent that causes the computation for q to result in 10, which is a corner + // case for Ryu. + TestRyu("6.7108864E17", Int32Bits2Float(0x5D1502F9)); + TestRyu("1.3421773E18", Int32Bits2Float(0x5D9502F9)); + TestRyu("2.6843546E18", Int32Bits2Float(0x5E1502F9)); +} + +TEST(Ryu, OutputLength) { + TestRyu("1E0", 1.0f); // already tested in Basic + TestRyu("1.2E0", 1.2f); + TestRyu("1.23E0", 1.23f); + TestRyu("1.234E0", 1.234f); + TestRyu("1.2345E0", 1.2345f); + TestRyu("1.23456E0", 1.23456f); + TestRyu("1.234567E0", 1.234567f); + TestRyu("1.2345678E0", 1.2345678f); + TestRyu("1.23456735E-36", 1.23456735E-36f); +} + +TEST(IntegerPrinting, Basic) { + TestInteger("0", 0); + auto str = std::to_string(std::numeric_limits::min()); + TestInteger(str.c_str(), std::numeric_limits::min()); + str = std::to_string(std::numeric_limits::max()); + TestInteger(str.c_str(), std::numeric_limits::max()); +} + +void TestRyuParse(float f, std::string in) { + float res; + auto ret = from_chars(in.c_str(), in.c_str() + in.size(), res); + ASSERT_EQ(ret.ec, std::errc()); + ASSERT_EQ(f, res); +} + +TEST(Ryu, Basic) { + TestRyuParse(0.0f, "0"); + TestRyuParse(-0.0f, "-0"); + TestRyuParse(1.0f, "1"); + TestRyuParse(-1.0f, "-1"); + TestRyuParse(123456792.0f, "123456789"); + TestRyuParse(299792448.0f, "299792458"); +} + +TEST(Ryu, MinMax) { + TestRyuParse(1e-45f, "1e-45"); + TestRyuParse(FLT_MIN, "1.1754944e-38"); + TestRyuParse(FLT_MAX, "3.4028235e+38"); +} + +TEST(Ryu, MantissaRoundingOverflow) { + TestRyuParse(1.0f, "0.999999999"); + TestRyuParse(INFINITY, "3.4028236e+38"); + TestRyuParse(1.1754944e-38f, "1.17549430e-38"); 
// FLT_MIN +} + +TEST(Ryu, TrailingZeros) { + TestRyuParse(26843550.0f, "26843549.5"); + TestRyuParse(50000004.0f, "50000002.5"); + TestRyuParse(99999992.0f, "99999989.5"); +} + +} // namespace xgboost diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 14ed78dda..365306fb8 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -248,9 +248,9 @@ TEST(HistUtil, AdapterDeviceSketch) { thrust::device_vector< float> data(rows*cols); auto json_array_interface = Generate2dArrayInterface(rows, cols, "{ 1.0,2.0,3.0,4.0,5.0 }; - std::stringstream ss; - Json::Dump(json_array_interface, &ss); - std::string str = ss.str(); + std::string str; + Json::Dump(json_array_interface, &str); + data::CupyAdapter adapter(str); auto device_cuts = AdapterDeviceSketch(&adapter, num_bins, missing); diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 08405e9f5..55edb324f 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -54,9 +54,8 @@ inline data::CupyAdapter AdapterFromData(const thrust::device_vector &x, array_interface["data"] = j_data; array_interface["version"] = Integer(static_cast(1)); array_interface["typestr"] = String("(json), 31.8892f, kRtEps); + ASSERT_EQ(get(json), 31.8892f); } { std::string str = "-31.8892"; auto json = Json::Load(StringView{str.c_str(), str.size()}); - ASSERT_NEAR(get(json), -31.8892f, kRtEps); + ASSERT_EQ(get(json), -31.8892f); } { std::string str = "2e4"; auto json = Json::Load(StringView{str.c_str(), str.size()}); - ASSERT_NEAR(get(json), 2e4f, kRtEps); + ASSERT_EQ(get(json), 2e4f); } { std::string str = "2e-4"; auto json = Json::Load(StringView{str.c_str(), str.size()}); - ASSERT_NEAR(get(json), 2e-4f, kRtEps); + ASSERT_EQ(get(json), 2e-4f); + } + { + std::string str = "-2e-4"; + auto json = Json::Load(StringView{str.c_str(), str.size()}); + ASSERT_EQ(get(json), -2e-4f); + } + { + std::string str = "-0.0"; + 
auto json = Json::Load(StringView{str.c_str(), str.size()}); + ASSERT_TRUE(std::signbit(get(json))); + ASSERT_EQ(get(json), -0); + } + { + std::string str = "-5.37645816802978516e-01"; + auto json = Json::Load(StringView{str.c_str(), str.size()}); + ASSERT_TRUE(std::signbit(get(json))); + // Larger than fast path limit. + ASSERT_EQ(get(json), -5.37645816802978516e-01); + } + { + std::string str = "9.86623668670654297e+00"; + auto json = Json::Load(StringView{str.c_str(), str.size()}); + ASSERT_FALSE(std::signbit(get(json))); + ASSERT_EQ(get(json), 9.86623668670654297e+00); } } @@ -200,13 +226,30 @@ TEST(Json, ParseArray) { Json v0 = arr[0]; ASSERT_EQ(get(v0["depth"]), 3); ASSERT_NEAR(get(v0["gain"]), 10.4866, kRtEps); + + { + std::string str = + "[5.04713470458984375e+02,9.86623668670654297e+00,4.94847229003906250e+" + "02,2.13924217224121094e+00,7.72699451446533203e+00,2." + "30380615234375000e+02,2.64466613769531250e+02]"; + auto json = Json::Load(StringView{str.c_str(), str.size()}); + + auto const& vec = get(json); + ASSERT_EQ(get(vec[0]), 5.04713470458984375e+02); + ASSERT_EQ(get(vec[1]), 9.86623668670654297e+00); + ASSERT_EQ(get(vec[2]), 4.94847229003906250e+02); + ASSERT_EQ(get(vec[3]), 2.13924217224121094e+00); + ASSERT_EQ(get(vec[4]), 7.72699451446533203e+00); + ASSERT_EQ(get(vec[5]), 2.30380615234375000e+02); + ASSERT_EQ(get(vec[6]), 2.64466613769531250e+02); + } } TEST(Json, Null) { Json json {JsonNull()}; - std::stringstream ss; + std::string ss; Json::Dump(json, &ss); - ASSERT_EQ(ss.str(), "null"); + ASSERT_EQ(ss, "null"); std::string null_input {R"null({"key": null })null"}; @@ -288,7 +331,7 @@ TEST(Json, AssigningObjects) { Json json_object { JsonObject() }; auto str = JsonString("1"); auto& k = json_object["1"]; - k = str; + k = std::move(str); auto& m = json_object["1"]; std::string value = get(m); ASSERT_EQ(value, "1"); @@ -365,15 +408,56 @@ TEST(Json, LoadDump) { dmlc::TemporaryDirectory tempdir; auto const& path = tempdir.path + 
"test_model_dump"; - std::ofstream fout (path); - Json::Dump(origin, &fout); - fout.close(); + std::string out; + Json::Dump(origin, &out); + + std::ofstream fout(path); + ASSERT_TRUE(fout); + fout << out << std::flush; std::string new_buffer = common::LoadSequentialFile(path); - Json load_back {Json::Load(StringView(new_buffer.c_str(), new_buffer.size()))}; - ASSERT_EQ(load_back, origin) << ori_buffer << "\n\n---------------\n\n" - << new_buffer; + Json load_back {Json::Load(StringView(new_buffer.c_str(), new_buffer.size()))}; + ASSERT_EQ(load_back, origin); +} + +TEST(Json, Invalid) { + { + std::string str = "}"; + bool has_thrown = false; + try { + Json load{Json::Load(StringView(str.c_str(), str.size()))}; + } catch (dmlc::Error const &e) { + std::string msg = e.what(); + ASSERT_NE(msg.find("Unknown"), std::string::npos); + has_thrown = true; + }; + ASSERT_TRUE(has_thrown); + } + { + std::string str = R"json({foo)json"; + bool has_thrown = false; + try { + Json load{Json::Load(StringView(str.c_str(), str.size()))}; + } catch (dmlc::Error const &e) { + std::string msg = e.what(); + ASSERT_NE(msg.find("position: 1"), std::string::npos); + has_thrown = true; + }; + ASSERT_TRUE(has_thrown); + } + { + std::string str = R"json({"foo")json"; + bool has_thrown = false; + try { + Json load{Json::Load(StringView(str.c_str(), str.size()))}; + } catch (dmlc::Error const &e) { + std::string msg = e.what(); + ASSERT_NE(msg.find("EOF"), std::string::npos); + has_thrown = true; + }; + ASSERT_TRUE(has_thrown); + } } // For now Json is quite ignorance about unicode. 
@@ -383,10 +467,9 @@ TEST(Json, CopyUnicode) { )json"; Json loaded {Json::Load(StringView{json_str.c_str(), json_str.size()})}; - std::stringstream ss_1; - Json::Dump(loaded, &ss_1); + std::string dumped_string; + Json::Dump(loaded, &dumped_string); - std::string dumped_string = ss_1.str(); ASSERT_NE(dumped_string.find("\\u20ac"), std::string::npos); } @@ -406,6 +489,15 @@ TEST(Json, WrongCasts) { } } +TEST(Json, Integer) { + for (int64_t i = 1; i < 10000; i *= 10) { + auto ten = Json{Integer{i}}; + std::string str; + Json::Dump(ten, &str); + ASSERT_EQ(str, std::to_string(i)); + } +} + TEST(Json, IntVSFloat) { // If integer is parsed as float, calling `get()' will throw. { @@ -432,4 +524,31 @@ TEST(Json, IntVSFloat) { ASSERT_EQ(ptr, 2503595760); } } + +TEST(Json, RoundTrip) { + uint32_t i = 0; + SimpleLCG rng; + SimpleRealUniformDistribution dist(1.0f, 4096.0f); + + while (i <= std::numeric_limits::max()) { + float f; + std::memcpy(&f, &i, sizeof(f)); + + Json jf { f }; + std::string str; + Json::Dump(jf, &str); + auto loaded = Json::Load({str.c_str(), str.size()}); + if (XGBOOST_EXPECT(std::isnan(f), false)) { + ASSERT_TRUE(std::isnan(get(loaded))); + } else { + ASSERT_EQ(get(loaded), f); + } + + auto t = i; + i+= static_cast(dist(&rng)); + if (i < t) { + break; + } + } +} } // namespace xgboost diff --git a/tests/cpp/common/test_span.cu b/tests/cpp/common/test_span.cu index 9aa0b8d53..00e00d4f4 100644 --- a/tests/cpp/common/test_span.cu +++ b/tests/cpp/common/test_span.cu @@ -50,7 +50,7 @@ __global__ void TestFromOtherKernel(Span span) { } } // Test converting different T - __global__ void TestFromOtherKernelConst(Span span) { +__global__ void TestFromOtherKernelConst(Span span) { // don't get optimized out size_t idx = threadIdx.x + blockIdx.x * blockDim.x; diff --git a/tests/cpp/data/test_device_adapter.cu b/tests/cpp/data/test_device_adapter.cu index 181304b3a..34c8e93b7 100644 --- a/tests/cpp/data/test_device_adapter.cu +++ 
b/tests/cpp/data/test_device_adapter.cu @@ -23,9 +23,8 @@ void TestCudfAdapter() Json column_arr {columns}; - std::stringstream ss; - Json::Dump(column_arr, &ss); - std::string str = ss.str(); + std::string str; + Json::Dump(column_arr, &str); data::CudfAdapter adapter(str); diff --git a/tests/cpp/data/test_device_dmatrix.cu b/tests/cpp/data/test_device_dmatrix.cu index db29cc574..7e0574c2e 100644 --- a/tests/cpp/data/test_device_dmatrix.cu +++ b/tests/cpp/data/test_device_dmatrix.cu @@ -78,9 +78,8 @@ TEST(DeviceDMatrix, ColumnMajor) { Json column_arr{columns}; - std::stringstream ss; - Json::Dump(column_arr, &ss); - std::string str = ss.str(); + std::string str; + Json::Dump(column_arr, &str); data::CudfAdapter adapter(str); data::DeviceDMatrix dmat(&adapter, std::numeric_limits::quiet_NaN(), diff --git a/tests/cpp/data/test_metainfo.cu b/tests/cpp/data/test_metainfo.cu index 2685cc3eb..23cb0f243 100644 --- a/tests/cpp/data/test_metainfo.cu +++ b/tests/cpp/data/test_metainfo.cu @@ -32,9 +32,8 @@ std::string PrepareData(std::string typestr, thrust::device_vector* out, cons column["data"] = j_data; Json array(std::vector{column}); - std::stringstream ss; - Json::Dump(array, &ss); - std::string str = ss.str(); + std::string str; + Json::Dump(array, &str); return str; } diff --git a/tests/cpp/data/test_simple_dmatrix.cu b/tests/cpp/data/test_simple_dmatrix.cu index ba492c34a..aff977bd2 100644 --- a/tests/cpp/data/test_simple_dmatrix.cu +++ b/tests/cpp/data/test_simple_dmatrix.cu @@ -22,9 +22,8 @@ TEST(SimpleDMatrix, FromColumnarDenseBasic) { Json column_arr{columns}; - std::stringstream ss; - Json::Dump(column_arr, &ss); - std::string str = ss.str(); + std::string str; + Json::Dump(column_arr, &str); data::CudfAdapter adapter(str); data::SimpleDMatrix dmat(&adapter, std::numeric_limits::quiet_NaN(), @@ -59,9 +58,8 @@ TEST(SimpleDMatrix, FromColumnarDense) { Json column_arr{columns}; - std::stringstream ss; - Json::Dump(column_arr, &ss); - std::string str = ss.str(); + 
std::string str; + Json::Dump(column_arr, &str); // no missing value { @@ -156,9 +154,9 @@ TEST(SimpleDMatrix, FromColumnarWithEmptyRows) { } Json column_arr{Array(v_columns)}; - std::stringstream ss; - Json::Dump(column_arr, &ss); - std::string str = ss.str(); + std::string str; + Json::Dump(column_arr, &str); + data::CudfAdapter adapter(str); data::SimpleDMatrix dmat(&adapter, std::numeric_limits::quiet_NaN(), -1); @@ -244,9 +242,8 @@ TEST(SimpleCSRSource, FromColumnarSparse) { Json column_arr {Array(j_columns)}; - std::stringstream ss; - Json::Dump(column_arr, &ss); - std::string str = ss.str(); + std::string str; + Json::Dump(column_arr, &str); { data::CudfAdapter adapter(str); @@ -296,9 +293,8 @@ TEST(SimpleDMatrix, FromColumnarSparseBasic) { Json column_arr{columns}; - std::stringstream ss; - Json::Dump(column_arr, &ss); - std::string str = ss.str(); + std::string str; + Json::Dump(column_arr, &str); data::CudfAdapter adapter(str); data::SimpleDMatrix dmat(&adapter, std::numeric_limits::quiet_NaN(), @@ -324,9 +320,8 @@ TEST(SimpleDMatrix, FromCupy){ int cols = 10; thrust::device_vector< float> data(rows*cols); auto json_array_interface = Generate2dArrayInterface(rows, cols, "::quiet_NaN(); data[2] = std::numeric_limits::quiet_NaN(); - std::stringstream ss; - Json::Dump(json_array_interface, &ss); - std::string str = ss.str(); + std::string str; + Json::Dump(json_array_interface, &str); data::CupyAdapter adapter(str); data::SimpleDMatrix dmat(&adapter, -1, 1); EXPECT_EQ(dmat.Info().num_col_, cols);