Implement fast number serialization routines. (#5772)
* Implement ryu algorithm. * Implement integer printing. * Full coverage roundtrip test.
This commit is contained in:
parent
7c3a168ffd
commit
38ee514787
@ -69,6 +69,7 @@
|
||||
#include "../src/learner.cc"
|
||||
#include "../src/logging.cc"
|
||||
#include "../src/common/common.cc"
|
||||
#include "../src/common/charconv.cc"
|
||||
#include "../src/common/timer.cc"
|
||||
#include "../src/common/host_device_vector.cc"
|
||||
#include "../src/common/hist_util.cc"
|
||||
|
||||
@ -1,18 +1,18 @@
|
||||
/*!
|
||||
* Copyright (c) by Contributors 2019
|
||||
* Copyright (c) by XGBoost Contributors 2019-2020
|
||||
*/
|
||||
#ifndef XGBOOST_JSON_H_
|
||||
#define XGBOOST_JSON_H_
|
||||
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/parameter.h>
|
||||
#include <string>
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -331,13 +331,7 @@ class Json {
|
||||
static Json Load(StringView str);
|
||||
/*! \brief Pass your own JsonReader. */
|
||||
static Json Load(JsonReader* reader);
|
||||
/*! \brief Dump json into stream. */
|
||||
static void Dump(Json json, std::ostream* stream,
|
||||
bool pretty = ConsoleLogger::ShouldLog(
|
||||
ConsoleLogger::LogVerbosity::kDebug));
|
||||
static void Dump(Json json, std::string* out,
|
||||
bool pretty = ConsoleLogger::ShouldLog(
|
||||
ConsoleLogger::LogVerbosity::kDebug));
|
||||
static void Dump(Json json, std::string* out);
|
||||
|
||||
Json() : ptr_{new JsonNull} {}
|
||||
|
||||
|
||||
@ -4,7 +4,9 @@
|
||||
#ifndef XGBOOST_JSON_IO_H_
|
||||
#define XGBOOST_JSON_IO_H_
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/base.h>
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <cinttypes>
|
||||
@ -15,20 +17,6 @@
|
||||
#include <locale>
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
template <typename Allocator>
|
||||
class FixedPrecisionStreamContainer : public std::basic_stringstream<
|
||||
char, std::char_traits<char>, Allocator> {
|
||||
public:
|
||||
FixedPrecisionStreamContainer() {
|
||||
this->precision(std::numeric_limits<double>::max_digits10);
|
||||
this->imbue(std::locale("C"));
|
||||
this->setf(std::ios::scientific);
|
||||
}
|
||||
};
|
||||
|
||||
using FixedPrecisionStream = FixedPrecisionStreamContainer<std::allocator<char>>;
|
||||
|
||||
/*
|
||||
* \brief A json reader, currently error checking and utf-8 is not fully supported.
|
||||
*/
|
||||
@ -45,13 +33,11 @@ class JsonReader {
|
||||
SourceLocation() = default;
|
||||
size_t Pos() const { return pos_; }
|
||||
|
||||
SourceLocation& Forward() {
|
||||
void Forward() {
|
||||
pos_++;
|
||||
return *this;
|
||||
}
|
||||
SourceLocation& Forward(uint32_t n) {
|
||||
void Forward(uint32_t n) {
|
||||
pos_ += n;
|
||||
return *this;
|
||||
}
|
||||
} cursor_;
|
||||
|
||||
@ -77,14 +63,17 @@ class JsonReader {
|
||||
return ch;
|
||||
}
|
||||
|
||||
/* \brief Skip spaces and consume next character. */
|
||||
char GetNextNonSpaceChar() {
|
||||
SkipSpaces();
|
||||
return GetNextChar();
|
||||
}
|
||||
|
||||
char GetChar(char c) {
|
||||
char result = GetNextNonSpaceChar();
|
||||
if (result != c) { Expect(c, result); }
|
||||
/* \brief Consume next character without first skipping empty space, throw when the next
|
||||
* character is not the expected one.
|
||||
*/
|
||||
char GetConsecutiveChar(char expected_char) {
|
||||
char result = GetNextChar();
|
||||
if (XGBOOST_EXPECT(result != expected_char, false)) { Expect(expected_char, result); }
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -95,7 +84,11 @@ class JsonReader {
|
||||
std::string msg = "Expecting: \"";
|
||||
msg += c;
|
||||
msg += "\", got: \"";
|
||||
msg += std::string {got} + " \"";
|
||||
if (got == -1) {
|
||||
msg += "EOF\"";
|
||||
} else {
|
||||
msg += std::to_string(got) + " \"";
|
||||
}
|
||||
Error(msg);
|
||||
}
|
||||
|
||||
@ -119,38 +112,16 @@ class JsonReader {
|
||||
|
||||
class JsonWriter {
|
||||
static constexpr size_t kIndentSize = 2;
|
||||
FixedPrecisionStream convertor_;
|
||||
|
||||
size_t n_spaces_;
|
||||
std::ostream* stream_;
|
||||
bool pretty_;
|
||||
std::vector<char>* stream_;
|
||||
|
||||
public:
|
||||
JsonWriter(std::ostream* stream, bool pretty) :
|
||||
n_spaces_{0}, stream_{stream}, pretty_{pretty} {}
|
||||
explicit JsonWriter(std::vector<char>* stream) :
|
||||
n_spaces_{0}, stream_{stream} {}
|
||||
|
||||
virtual ~JsonWriter() = default;
|
||||
|
||||
void NewLine() {
|
||||
if (pretty_) {
|
||||
*stream_ << u8"\n" << std::string(n_spaces_, ' ');
|
||||
}
|
||||
}
|
||||
|
||||
void BeginIndent() {
|
||||
n_spaces_ += kIndentSize;
|
||||
}
|
||||
void EndIndent() {
|
||||
n_spaces_ -= kIndentSize;
|
||||
}
|
||||
|
||||
void Write(std::string str) {
|
||||
*stream_ << str;
|
||||
}
|
||||
void Write(StringView str) {
|
||||
stream_->write(str.c_str(), str.size());
|
||||
}
|
||||
|
||||
void Save(Json json);
|
||||
|
||||
virtual void Visit(JsonArray const* arr);
|
||||
|
||||
942
src/common/charconv.cc
Normal file
942
src/common/charconv.cc
Normal file
@ -0,0 +1,942 @@
|
||||
/*!
|
||||
* Copyright 2020 by XGBoost Contributors
|
||||
*
|
||||
* \brief An implemenation of Ryu algorithm:
|
||||
*
|
||||
* https://dl.acm.org/citation.cfm?id=3192369
|
||||
*
|
||||
* The code is adopted from original (half) c implementation:
|
||||
* https://github.com/ulfjack/ryu.git with some more comments and tidying. License is
|
||||
* attached below.
|
||||
*
|
||||
* Copyright 2018 Ulf Adams
|
||||
*
|
||||
* The contents of this file may be used under the terms of the Apache License,
|
||||
* Version 2.0.
|
||||
*
|
||||
* (See accompanying file LICENSE-Apache or copy at
|
||||
* http: *www.apache.org/licenses/LICENSE-2.0)
|
||||
*
|
||||
* Alternatively, the contents of this file may be used under the terms of
|
||||
* the Boost Software License, Version 1.0.
|
||||
* (See accompanying file LICENSE-Boost or copy at
|
||||
* https://www.boost.org/LICENSE_1_0.txt)
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, this software
|
||||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cinttypes>
|
||||
#include <cstring>
|
||||
#include <cmath>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "charconv.h"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#include <intrin.h>
|
||||
namespace {
|
||||
inline int32_t __builtin_clzll(uint64_t x) {
|
||||
return static_cast<int32_t>(__lzcnt64(x));
|
||||
}
|
||||
} // anonymous namespace
|
||||
#endif
|
||||
|
||||
/*
|
||||
* We did some cleanup from the original implementation instead of doing line to line
|
||||
* port.
|
||||
*
|
||||
* The basic concept of floating rounding is, for a floating point number, we need to
|
||||
* convert base2 to base10. During which we need to implement correct rounding. Hence on
|
||||
* base2 we have:
|
||||
*
|
||||
* {low, value, high}
|
||||
*
|
||||
* 3 values, representing round down, no rounding, and round up. In the original
|
||||
* implementation and paper, variables representing these 3 values are typically postfixed
|
||||
* with m, r, p like {vr, vm, vp}. Here we name them more verbosely.
|
||||
*/
|
||||
|
||||
namespace xgboost {
|
||||
namespace detail {
|
||||
static constexpr char kItoaLut[200] = {
|
||||
'0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0',
|
||||
'7', '0', '8', '0', '9', '1', '0', '1', '1', '1', '2', '1', '3', '1', '4',
|
||||
'1', '5', '1', '6', '1', '7', '1', '8', '1', '9', '2', '0', '2', '1', '2',
|
||||
'2', '2', '3', '2', '4', '2', '5', '2', '6', '2', '7', '2', '8', '2', '9',
|
||||
'3', '0', '3', '1', '3', '2', '3', '3', '3', '4', '3', '5', '3', '6', '3',
|
||||
'7', '3', '8', '3', '9', '4', '0', '4', '1', '4', '2', '4', '3', '4', '4',
|
||||
'4', '5', '4', '6', '4', '7', '4', '8', '4', '9', '5', '0', '5', '1', '5',
|
||||
'2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5', '9',
|
||||
'6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6',
|
||||
'7', '6', '8', '6', '9', '7', '0', '7', '1', '7', '2', '7', '3', '7', '4',
|
||||
'7', '5', '7', '6', '7', '7', '7', '8', '7', '9', '8', '0', '8', '1', '8',
|
||||
'2', '8', '3', '8', '4', '8', '5', '8', '6', '8', '7', '8', '8', '8', '9',
|
||||
'9', '0', '9', '1', '9', '2', '9', '3', '9', '4', '9', '5', '9', '6', '9',
|
||||
'7', '9', '8', '9', '9'};
|
||||
|
||||
constexpr uint32_t Tens(uint32_t n) { return n == 1 ? 10 : (Tens(n - 1) * 10); }
|
||||
|
||||
struct UnsignedFloatBase2;
|
||||
|
||||
struct UnsignedFloatBase10 {
|
||||
uint32_t mantissa;
|
||||
// Decimal exponent's range is -45 to 38
|
||||
// inclusive, and can fit in a short if needed.
|
||||
int32_t exponent;
|
||||
};
|
||||
|
||||
template <typename To, typename From>
|
||||
To BitCast(From&& from) {
|
||||
static_assert(sizeof(From) == sizeof(To), "Bit cast doesn't change output size.");
|
||||
To t;
|
||||
std::memcpy(&t, &from, sizeof(To));
|
||||
return t;
|
||||
}
|
||||
|
||||
struct IEEE754 {
|
||||
static constexpr uint32_t kFloatMantissaBits = 23;
|
||||
static constexpr uint32_t kFloatBias = 127;
|
||||
static constexpr uint32_t kFloatExponentBits = 8;
|
||||
|
||||
static void Decode(float f, UnsignedFloatBase2* uf, bool* signbit);
|
||||
static float Encode(UnsignedFloatBase2 const& uf, bool signbit);
|
||||
|
||||
static float Infinity(bool sign) {
|
||||
uint32_t f =
|
||||
((static_cast<uint32_t>(sign))
|
||||
<< (IEEE754::kFloatExponentBits + IEEE754::kFloatMantissaBits)) |
|
||||
(0xffu << IEEE754::kFloatMantissaBits);
|
||||
float result = BitCast<float>(f);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
struct UnsignedFloatBase2 {
|
||||
uint32_t mantissa;
|
||||
// Decimal exponent's range is -45 to 38
|
||||
// inclusive, and can fit in a short if needed.
|
||||
uint32_t exponent;
|
||||
|
||||
bool Infinite() const {
|
||||
return exponent == ((1u << IEEE754::kFloatExponentBits) - 1u);
|
||||
}
|
||||
bool Zero() const {
|
||||
return mantissa == 0 && exponent == 0;
|
||||
}
|
||||
};
|
||||
|
||||
inline void IEEE754::Decode(float f, UnsignedFloatBase2 *uf, bool *signbit) {
|
||||
auto bits = BitCast<uint32_t>(f);
|
||||
// Decode bits into sign, mantissa, and exponent.
|
||||
*signbit = std::signbit(f);
|
||||
uf->mantissa = bits & ((1u << kFloatMantissaBits) - 1);
|
||||
uf->exponent = (bits >> IEEE754::kFloatMantissaBits) &
|
||||
((1u << IEEE754::kFloatExponentBits) - 1); // remove signbit
|
||||
}
|
||||
|
||||
inline float IEEE754::Encode(UnsignedFloatBase2 const &uf, bool signbit) {
|
||||
uint32_t f =
|
||||
((((static_cast<uint32_t>(signbit)) << IEEE754::kFloatExponentBits) |
|
||||
static_cast<uint32_t>(uf.exponent))
|
||||
<< IEEE754::kFloatMantissaBits) |
|
||||
uf.mantissa;
|
||||
return BitCast<float>(f);
|
||||
}
|
||||
|
||||
// Represents the interval of information-preserving outputs.
|
||||
struct MantissaInteval {
|
||||
int32_t exponent;
|
||||
// low: smaller half way point
|
||||
uint32_t mantissa_low;
|
||||
// correct: f
|
||||
uint32_t mantissa_correct;
|
||||
// high: larger half way point
|
||||
uint32_t mantissa_high;
|
||||
};
|
||||
|
||||
struct RyuPowLogUtils {
|
||||
// This table is generated by PrintFloatLookupTable from ryu. We adopted only the float
|
||||
// 32 table instead of double full table.
|
||||
// f2s_full_table.h
|
||||
uint32_t constexpr static kFloatPow5InvBitcount = 59;
|
||||
static constexpr uint64_t kFloatPow5InvSplit[55] = {
|
||||
576460752303423489u, 461168601842738791u, 368934881474191033u,
|
||||
295147905179352826u, 472236648286964522u, 377789318629571618u,
|
||||
302231454903657294u, 483570327845851670u, 386856262276681336u,
|
||||
309485009821345069u, 495176015714152110u, 396140812571321688u,
|
||||
316912650057057351u, 507060240091291761u, 405648192073033409u,
|
||||
324518553658426727u, 519229685853482763u, 415383748682786211u,
|
||||
332306998946228969u, 531691198313966350u, 425352958651173080u,
|
||||
340282366920938464u, 544451787073501542u, 435561429658801234u,
|
||||
348449143727040987u, 557518629963265579u, 446014903970612463u,
|
||||
356811923176489971u, 570899077082383953u, 456719261665907162u,
|
||||
365375409332725730u, 292300327466180584u, 467680523945888934u,
|
||||
374144419156711148u, 299315535325368918u, 478904856520590269u,
|
||||
383123885216472215u, 306499108173177772u, 490398573077084435u,
|
||||
392318858461667548u, 313855086769334039u, 502168138830934462u,
|
||||
401734511064747569u, 321387608851798056u, 514220174162876889u,
|
||||
411376139330301511u, 329100911464241209u, 526561458342785934u,
|
||||
421249166674228747u, 336999333339382998u, 539198933343012796u,
|
||||
431359146674410237u, 345087317339528190u, 552139707743245103u,
|
||||
441711766194596083u};
|
||||
|
||||
uint32_t constexpr static kFloatPow5Bitcount = 61;
|
||||
static constexpr uint64_t kFloatPow5Split[47] = {
|
||||
1152921504606846976u, 1441151880758558720u, 1801439850948198400u,
|
||||
2251799813685248000u, 1407374883553280000u, 1759218604441600000u,
|
||||
2199023255552000000u, 1374389534720000000u, 1717986918400000000u,
|
||||
2147483648000000000u, 1342177280000000000u, 1677721600000000000u,
|
||||
2097152000000000000u, 1310720000000000000u, 1638400000000000000u,
|
||||
2048000000000000000u, 1280000000000000000u, 1600000000000000000u,
|
||||
2000000000000000000u, 1250000000000000000u, 1562500000000000000u,
|
||||
1953125000000000000u, 1220703125000000000u, 1525878906250000000u,
|
||||
1907348632812500000u, 1192092895507812500u, 1490116119384765625u,
|
||||
1862645149230957031u, 1164153218269348144u, 1455191522836685180u,
|
||||
1818989403545856475u, 2273736754432320594u, 1421085471520200371u,
|
||||
1776356839400250464u, 2220446049250313080u, 1387778780781445675u,
|
||||
1734723475976807094u, 2168404344971008868u, 1355252715606880542u,
|
||||
1694065894508600678u, 2117582368135750847u, 1323488980084844279u,
|
||||
1654361225106055349u, 2067951531382569187u, 1292469707114105741u,
|
||||
1615587133892632177u, 2019483917365790221u};
|
||||
|
||||
static uint32_t Pow5Factor(uint32_t value) noexcept(true) {
|
||||
uint32_t count = 0;
|
||||
for (;;) {
|
||||
const uint32_t q = value / 5;
|
||||
const uint32_t r = value % 5;
|
||||
if (r != 0) {
|
||||
break;
|
||||
}
|
||||
value = q;
|
||||
++count;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
// Returns true if value is divisible by 5^p.
|
||||
static bool MultipleOfPowerOf5(const uint32_t value, const uint32_t p) noexcept(true) {
|
||||
return Pow5Factor(value) >= p;
|
||||
}
|
||||
|
||||
// Returns true if value is divisible by 2^p.
|
||||
static bool MultipleOfPowerOf2(const uint32_t value, const uint32_t p) noexcept(true) {
|
||||
#ifdef __GNUC__
|
||||
return static_cast<uint32_t>(__builtin_ctz(value)) >= p;
|
||||
#else
|
||||
return (value & ((1u << p) - 1)) == 0;
|
||||
#endif // __GNUC__
|
||||
}
|
||||
|
||||
// Returns e == 0 ? 1 : ceil(log_2(5^e)).
|
||||
static uint32_t Pow5Bits(const int32_t e) noexcept(true) {
|
||||
return static_cast<uint32_t>(((e * 163391164108059ull) >> 46) + 1);
|
||||
}
|
||||
|
||||
static int32_t Log2Pow5(const int32_t e) {
|
||||
// This approximation works up to the point that the multiplication
|
||||
// overflows at e = 3529. If the multiplication were done in 64 bits, it
|
||||
// would fail at 5^4004 which is just greater than 2^9297.
|
||||
assert(e >= 0);
|
||||
assert(e <= 3528);
|
||||
return static_cast<int32_t>(((static_cast<uint32_t>(e)) * 1217359) >> 19);
|
||||
}
|
||||
|
||||
static int32_t CeilLog2Pow5(const int32_t e) {
|
||||
return RyuPowLogUtils::Log2Pow5(e) + 1;
|
||||
}
|
||||
|
||||
/*
|
||||
* \brief Multiply 32-bit and 64-bit -> 128 bit, then access the higher bits.
|
||||
*/
|
||||
static uint32_t MulShift(const uint32_t x, const uint64_t y,
|
||||
const int32_t shift) noexcept(true) {
|
||||
// For 32-bit * 64-bit: x * y, it can be decomposed into:
|
||||
//
|
||||
// x * (y_high + y_low) = (x * y_high) + (x * y_low)
|
||||
//
|
||||
// For more general case 64-bit * 64-bit, see https://stackoverflow.com/a/1541458
|
||||
const uint32_t y_low = static_cast<uint32_t>(y);
|
||||
const uint32_t y_high = static_cast<uint32_t>(y >> 32);
|
||||
|
||||
const uint64_t low = static_cast<uint64_t>(x) * y_low;
|
||||
const uint64_t high = static_cast<uint64_t>(x) * y_high;
|
||||
|
||||
const uint64_t sum = (low >> 32) + high;
|
||||
const uint64_t shifted_sum = sum >> (shift - 32);
|
||||
|
||||
return static_cast<uint32_t>(shifted_sum);
|
||||
}
|
||||
|
||||
/*
|
||||
* \brief floor(5^q/2*k) and shift by j
|
||||
*/
|
||||
static uint32_t MulPow5InvDivPow2(const uint32_t m, const uint32_t q,
|
||||
const int32_t j) noexcept(true) {
|
||||
return MulShift(m, kFloatPow5InvSplit[q], j);
|
||||
}
|
||||
|
||||
/*
|
||||
* \brief floor(2^k/5^q) + 1 and shift by j
|
||||
*/
|
||||
static uint32_t MulPow5divPow2(const uint32_t m, const uint32_t i,
|
||||
const int32_t j) noexcept(true) {
|
||||
// clang-tidy makes false assumption that can lead to i >= 47, which is impossible.
|
||||
// Can be verified by enumerating all float32 values.
|
||||
return MulShift(m, kFloatPow5Split[i], j); // NOLINT
|
||||
}
|
||||
|
||||
static uint32_t FloorLog2(const uint64_t value) {
|
||||
return 63 - __builtin_clzll(value);
|
||||
}
|
||||
|
||||
/*
|
||||
* \brief floor(e * log_10(2)).
|
||||
*/
|
||||
static uint32_t Log10Pow2(const int32_t e) noexcept(true) {
|
||||
// The first value this approximation fails for is 2^1651 which is just
|
||||
// greater than 10^297.
|
||||
assert(e >= 0);
|
||||
assert(e <= 1 << 15);
|
||||
return static_cast<uint32_t>((static_cast<uint64_t>(e) * 169464822037455ull) >> 49);
|
||||
}
|
||||
|
||||
// Returns floor(e * log_10(5)).
|
||||
static uint32_t Log10Pow5(const int32_t expoent) noexcept(true) {
|
||||
// The first value this approximation fails for is 5^2621 which is just
|
||||
// greater than 10^1832.
|
||||
assert(expoent >= 0);
|
||||
assert(expoent <= 1 << 15);
|
||||
return static_cast<uint32_t>(
|
||||
((static_cast<uint64_t>(expoent)) * 196742565691928ull) >> 48);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr uint64_t RyuPowLogUtils::kFloatPow5InvSplit[55];
|
||||
constexpr uint64_t RyuPowLogUtils::kFloatPow5Split[47];
|
||||
|
||||
class PowerBaseComputer {
|
||||
private:
|
||||
static uint8_t
|
||||
ToDecimalBase(bool const accept_bounds, uint32_t const mantissa_low_shift,
|
||||
MantissaInteval const base2, MantissaInteval *base10,
|
||||
bool *mantissa_low_is_trailing_zeros,
|
||||
bool *mantissa_out_is_trailing_zeros) noexcept(true) {
|
||||
uint8_t last_removed_digit = 0;
|
||||
if (base2.exponent >= 0) {
|
||||
const uint32_t q = RyuPowLogUtils::Log10Pow2(base2.exponent);
|
||||
base10->exponent = static_cast<int32_t>(q);
|
||||
const int32_t k = RyuPowLogUtils::kFloatPow5InvBitcount +
|
||||
RyuPowLogUtils::Pow5Bits(static_cast<int32_t>(q)) - 1;
|
||||
const int32_t i = -base2.exponent + static_cast<int32_t>(q) + k;
|
||||
base10->mantissa_low =
|
||||
RyuPowLogUtils::MulPow5InvDivPow2(base2.mantissa_low, q, i);
|
||||
base10->mantissa_correct =
|
||||
RyuPowLogUtils::MulPow5InvDivPow2(base2.mantissa_correct, q, i);
|
||||
base10->mantissa_high =
|
||||
RyuPowLogUtils::MulPow5InvDivPow2(base2.mantissa_high, q, i);
|
||||
|
||||
if (q != 0 &&
|
||||
(base10->mantissa_high - 1) / 10 <= base10->mantissa_low / 10) {
|
||||
// We need to know one removed digit even if we are not going to loop
|
||||
// below. We could use q = X - 1 above, except that would require 33
|
||||
// bits for the result, and we've found that 32-bit arithmetic is
|
||||
// faster even on 64-bit machines.
|
||||
const int32_t l =
|
||||
RyuPowLogUtils::kFloatPow5InvBitcount +
|
||||
RyuPowLogUtils::Pow5Bits(static_cast<int32_t>(q - 1)) - 1;
|
||||
last_removed_digit = static_cast<uint8_t>(
|
||||
RyuPowLogUtils::MulPow5InvDivPow2(
|
||||
base2.mantissa_correct, q - 1,
|
||||
-base2.exponent + static_cast<int32_t>(q) - 1 + l) %
|
||||
10);
|
||||
}
|
||||
if (q <= 9) {
|
||||
// The largest power of 5 that fits in 24 bits is 5^10, but q <= 9 seems to be
|
||||
// safe as well. Only one of mantissa_high, mantissa_correct, and mantissa_low can
|
||||
// be a multiple of 5, if any.
|
||||
if (base2.mantissa_correct % 5 == 0) {
|
||||
*mantissa_out_is_trailing_zeros =
|
||||
RyuPowLogUtils::MultipleOfPowerOf5(base2.mantissa_correct, q);
|
||||
} else if (accept_bounds) {
|
||||
*mantissa_low_is_trailing_zeros =
|
||||
RyuPowLogUtils::MultipleOfPowerOf5(base2.mantissa_low, q);
|
||||
} else {
|
||||
base10->mantissa_high -=
|
||||
RyuPowLogUtils::MultipleOfPowerOf5(base2.mantissa_high, q);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const uint32_t q = RyuPowLogUtils::Log10Pow5(-base2.exponent);
|
||||
base10->exponent = static_cast<int32_t>(q) + base2.exponent;
|
||||
const int32_t i = -base2.exponent - static_cast<int32_t>(q);
|
||||
const int32_t k =
|
||||
RyuPowLogUtils::Pow5Bits(i) - RyuPowLogUtils::kFloatPow5Bitcount;
|
||||
int32_t j = static_cast<int32_t>(q) - k;
|
||||
base10->mantissa_correct = RyuPowLogUtils::MulPow5divPow2(
|
||||
base2.mantissa_correct, static_cast<uint32_t>(i), j);
|
||||
base10->mantissa_high = RyuPowLogUtils::MulPow5divPow2(
|
||||
base2.mantissa_high, static_cast<uint32_t>(i), j);
|
||||
base10->mantissa_low = RyuPowLogUtils::MulPow5divPow2(
|
||||
base2.mantissa_low, static_cast<uint32_t>(i), j);
|
||||
|
||||
if (q != 0 &&
|
||||
(base10->mantissa_high - 1) / 10 <= base10->mantissa_low / 10) {
|
||||
j = static_cast<int32_t>(q) - 1 -
|
||||
(RyuPowLogUtils::Pow5Bits(i + 1) -
|
||||
RyuPowLogUtils::kFloatPow5Bitcount);
|
||||
last_removed_digit = static_cast<uint8_t>(
|
||||
RyuPowLogUtils::MulPow5divPow2(base2.mantissa_correct,
|
||||
static_cast<uint32_t>(i + 1), j) %
|
||||
10);
|
||||
}
|
||||
if (q <= 1) {
|
||||
// {mantissa_out, mantissa_out_high, mantissa_out_low} is trailing zeros if
|
||||
// {mantissa_correct,mantissa_high,mantissa_low} has at least q trailing 0
|
||||
// bits.mantissa_correct = 4 * m2, so it always has at least two trailing 0 bits.
|
||||
*mantissa_out_is_trailing_zeros = true;
|
||||
if (accept_bounds) {
|
||||
// mantissa_low = mantissa_correct - 1 - mantissa_low_shift, so it has 1
|
||||
// trailing 0 bit iff mmShift == 1.
|
||||
*mantissa_low_is_trailing_zeros = mantissa_low_shift == 1;
|
||||
} else {
|
||||
// mantissa_high = mantissa_correct + 2, so it always has at least one trailing
|
||||
// 0 bit.
|
||||
--base10->mantissa_high;
|
||||
}
|
||||
} else if (q < 31) {
|
||||
*mantissa_out_is_trailing_zeros =
|
||||
RyuPowLogUtils::MultipleOfPowerOf2(base2.mantissa_correct, q - 1);
|
||||
}
|
||||
}
|
||||
return last_removed_digit;
|
||||
}
|
||||
|
||||
/*
|
||||
* \brief A varient of extended euclidean GCD algorithm.
|
||||
*/
|
||||
static UnsignedFloatBase10
|
||||
ShortestRepresentation(bool mantissa_low_is_trailing_zeros,
|
||||
bool mantissa_out_is_trailing_zeros,
|
||||
uint8_t last_removed_digit, bool const accept_bounds,
|
||||
MantissaInteval base10) noexcept(true) {
|
||||
int32_t removed {0};
|
||||
uint32_t output {0};
|
||||
|
||||
if (mantissa_low_is_trailing_zeros || mantissa_out_is_trailing_zeros) {
|
||||
// General case, which happens rarely (~4.0%).
|
||||
while (base10.mantissa_high / 10 > base10.mantissa_low / 10) {
|
||||
mantissa_low_is_trailing_zeros &= base10.mantissa_low % 10 == 0;
|
||||
mantissa_out_is_trailing_zeros &= last_removed_digit == 0;
|
||||
last_removed_digit = static_cast<uint8_t>(base10.mantissa_correct % 10);
|
||||
base10.mantissa_correct /= 10;
|
||||
base10.mantissa_high /= 10;
|
||||
base10.mantissa_low /= 10;
|
||||
++removed;
|
||||
}
|
||||
|
||||
if (mantissa_low_is_trailing_zeros) {
|
||||
while (base10.mantissa_low % 10 == 0) {
|
||||
mantissa_out_is_trailing_zeros &= last_removed_digit == 0;
|
||||
last_removed_digit = static_cast<uint8_t>(base10.mantissa_correct % 10);
|
||||
base10.mantissa_correct /= 10;
|
||||
base10.mantissa_high /= 10;
|
||||
base10.mantissa_low /= 10;
|
||||
++removed;
|
||||
}
|
||||
}
|
||||
|
||||
if (mantissa_out_is_trailing_zeros && last_removed_digit == 5 &&
|
||||
base10.mantissa_correct % 2 == 0) {
|
||||
// Round even if the exact number is .....50..0.
|
||||
last_removed_digit = 4;
|
||||
}
|
||||
// We need to take mantissa_out + 1 if mantissa_out is outside bounds or we need to
|
||||
// round up.
|
||||
output = base10.mantissa_correct +
|
||||
((base10.mantissa_correct == base10.mantissa_low &&
|
||||
(!accept_bounds || !mantissa_low_is_trailing_zeros)) ||
|
||||
last_removed_digit >= 5);
|
||||
} else {
|
||||
// Specialized for the common case (~96.0%). Percentages below are
|
||||
// relative to this. Loop iterations below (approximately): 0: 13.6%,
|
||||
// 1: 70.7%, 2: 14.1%, 3: 1.39%, 4: 0.14%, 5+: 0.01%
|
||||
while (base10.mantissa_high / 10 > base10.mantissa_low / 10) {
|
||||
last_removed_digit = static_cast<uint8_t>(base10.mantissa_correct % 10);
|
||||
base10.mantissa_correct /= 10;
|
||||
base10.mantissa_high /= 10;
|
||||
base10.mantissa_low /= 10;
|
||||
++removed;
|
||||
}
|
||||
|
||||
// We need to take mantissa_out + 1 if mantissa_out is outside bounds or we need to
|
||||
// round up.
|
||||
output = base10.mantissa_correct +
|
||||
(base10.mantissa_correct == base10.mantissa_low ||
|
||||
last_removed_digit >= 5);
|
||||
}
|
||||
const int32_t exp = base10.exponent + removed;
|
||||
|
||||
UnsignedFloatBase10 fd;
|
||||
fd.exponent = exp;
|
||||
fd.mantissa = output;
|
||||
return fd;
|
||||
}
|
||||
|
||||
public:
|
||||
static UnsignedFloatBase10 Binary2Decimal(UnsignedFloatBase2 const f) noexcept(true) {
|
||||
MantissaInteval base2_range;
|
||||
uint32_t mantissa_base2;
|
||||
if (f.exponent == 0) {
|
||||
// We subtract 2 so that the bounds computation has 2 additional bits.
|
||||
base2_range.exponent = static_cast<int32_t>(1) -
|
||||
static_cast<int32_t>(IEEE754::kFloatBias) -
|
||||
static_cast<int32_t>(IEEE754::kFloatMantissaBits) -
|
||||
static_cast<int32_t>(2);
|
||||
static_assert(static_cast<int32_t>(1) -
|
||||
static_cast<int32_t>(IEEE754::kFloatBias) -
|
||||
static_cast<int32_t>(IEEE754::kFloatMantissaBits) -
|
||||
static_cast<int32_t>(2) ==
|
||||
-151,
|
||||
"");
|
||||
mantissa_base2 = f.mantissa;
|
||||
} else {
|
||||
base2_range.exponent = static_cast<int32_t>(f.exponent) - IEEE754::kFloatBias -
|
||||
IEEE754::kFloatMantissaBits - 2;
|
||||
mantissa_base2 = (1u << IEEE754::kFloatMantissaBits) | f.mantissa;
|
||||
}
|
||||
const bool even = (mantissa_base2 & 1) == 0;
|
||||
const bool accept_bounds = even;
|
||||
|
||||
// Step 2: Determine the interval of valid decimal representations.
|
||||
base2_range.mantissa_correct = 4 * mantissa_base2;
|
||||
base2_range.mantissa_high = 4 * mantissa_base2 + 2;
|
||||
// Implicit bool -> int conversion. True is 1, false is 0.
|
||||
const uint32_t mantissa_low_shift = f.mantissa != 0 || f.exponent <= 1;
|
||||
base2_range.mantissa_low = 4 * mantissa_base2 - 1 - mantissa_low_shift;
|
||||
|
||||
// Step 3: Convert to a decimal power base using 64-bit arithmetic.
|
||||
MantissaInteval base10_range;
|
||||
bool mantissa_low_is_trailing_zeros = false;
|
||||
bool mantissa_out_is_trailing_zeros = false;
|
||||
auto last_removed_digit = PowerBaseComputer::ToDecimalBase(
|
||||
accept_bounds, mantissa_low_shift, base2_range, &base10_range,
|
||||
&mantissa_low_is_trailing_zeros, &mantissa_out_is_trailing_zeros);
|
||||
|
||||
// Step 4: Find the shortest decimal representation in the interval of valid
|
||||
// representations.
|
||||
auto out = ShortestRepresentation(mantissa_low_is_trailing_zeros,
|
||||
mantissa_out_is_trailing_zeros,
|
||||
last_removed_digit,
|
||||
accept_bounds, base10_range);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* \brief Print the floating point number in base 10.
|
||||
*/
|
||||
class RyuPrinter {
|
||||
private:
|
||||
static inline uint32_t OutputLength(const uint32_t v) noexcept(true) {
|
||||
// Function precondition: v is not a 10-digit number.
|
||||
// (f2s: 9 digits are sufficient for round-tripping.)
|
||||
// (d2fixed: We print 9-digit blocks.)
|
||||
static_assert(100000000 == Tens(8), "");
|
||||
assert(v < Tens(9));
|
||||
if (v >= Tens(8)) {
|
||||
return 9;
|
||||
}
|
||||
if (v >= Tens(7)) {
|
||||
return 8;
|
||||
}
|
||||
if (v >= Tens(6)) {
|
||||
return 7;
|
||||
}
|
||||
if (v >= Tens(5)) {
|
||||
return 6;
|
||||
}
|
||||
if (v >= Tens(4)) {
|
||||
return 5;
|
||||
}
|
||||
if (v >= Tens(3)) {
|
||||
return 4;
|
||||
}
|
||||
if (v >= Tens(2)) {
|
||||
return 3;
|
||||
}
|
||||
if (v >= Tens(1)) {
|
||||
return 2;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
public:
|
||||
static int32_t PrintBase10Float(UnsignedFloatBase10 v, const bool sign,
|
||||
char *const result) noexcept(true) {
|
||||
// Step 5: Print the decimal representation.
|
||||
int index = 0;
|
||||
if (sign) {
|
||||
result[index++] = '-';
|
||||
}
|
||||
|
||||
uint32_t output = v.mantissa;
|
||||
const uint32_t out_length = OutputLength(output);
|
||||
|
||||
// Print the decimal digits.
|
||||
// The following code is equivalent to:
|
||||
// for (uint32_t i = 0; i < olength - 1; ++i) {
|
||||
// const uint32_t c = output % 10; output /= 10;
|
||||
// result[index + olength - i] = (char) ('0' + c);
|
||||
// }
|
||||
// result[index] = '0' + output % 10;
|
||||
uint32_t i = 0;
|
||||
while (output >= Tens(4)) {
|
||||
const uint32_t c = output % Tens(4);
|
||||
output /= Tens(4);
|
||||
const uint32_t c0 = (c % 100) << 1;
|
||||
const uint32_t c1 = (c / 100) << 1;
|
||||
// This is used to speed up decimal digit generation by copying
|
||||
// pairs of digits into the final output.
|
||||
std::memcpy(result + index + out_length - i - 1, kItoaLut + c0, 2);
|
||||
std::memcpy(result + index + out_length - i - 3, kItoaLut + c1, 2);
|
||||
i += 4;
|
||||
}
|
||||
if (output >= 100) {
|
||||
const uint32_t c = (output % 100) << 1;
|
||||
output /= 100;
|
||||
std::memcpy(result + index + out_length - i - 1, kItoaLut + c, 2);
|
||||
i += 2;
|
||||
}
|
||||
if (output >= 10) {
|
||||
const uint32_t c = output << 1;
|
||||
// We can't use std::memcpy here: the decimal dot goes between these two
|
||||
// digits.
|
||||
result[index + out_length - i] = kItoaLut[c + 1];
|
||||
result[index] = kItoaLut[c];
|
||||
} else {
|
||||
result[index] = static_cast<char>('0' + output);
|
||||
}
|
||||
|
||||
// Print decimal point if needed.
|
||||
if (out_length > 1) {
|
||||
result[index + 1] = '.';
|
||||
index += out_length + 1;
|
||||
} else {
|
||||
++index;
|
||||
}
|
||||
|
||||
// Print the exponent.
|
||||
result[index++] = 'E';
|
||||
int32_t exp = v.exponent + static_cast<int32_t>(out_length) - 1;
|
||||
if (exp < 0) {
|
||||
result[index++] = '-';
|
||||
exp = -exp;
|
||||
}
|
||||
|
||||
if (exp >= 10) {
|
||||
std::memcpy(result + index, kItoaLut + 2 * exp, 2);
|
||||
index += 2;
|
||||
} else {
|
||||
result[index++] = static_cast<char>('0' + exp);
|
||||
}
|
||||
|
||||
return index;
|
||||
}
|
||||
|
||||
static int32_t PrintSpecialFloat(const bool sign, UnsignedFloatBase2 f,
|
||||
char *const result) noexcept(true) {
|
||||
if (f.mantissa) {
|
||||
std::memcpy(result, u8"NaN", 3);
|
||||
return 3;
|
||||
}
|
||||
if (sign) {
|
||||
result[0] = '-';
|
||||
}
|
||||
if (f.exponent) {
|
||||
std::memcpy(result + sign, u8"Infinity", 8);
|
||||
return sign + 8;
|
||||
}
|
||||
std::memcpy(result + sign, u8"0E0", 3);
|
||||
return sign + 3;
|
||||
}
|
||||
};
|
||||
|
||||
int32_t ToCharsFloatImpl(float f, char * const result) {
|
||||
// Step 1: Decode the floating-point number, and unify normalized and
|
||||
// subnormal cases.
|
||||
UnsignedFloatBase2 uf32;
|
||||
bool sign;
|
||||
IEEE754::Decode(f, &uf32, &sign);
|
||||
|
||||
// Case distinction; exit early for the easy cases.
|
||||
if (uf32.Infinite() || uf32.Zero()) {
|
||||
return RyuPrinter::PrintSpecialFloat(sign, uf32, result);
|
||||
}
|
||||
|
||||
const UnsignedFloatBase10 v = PowerBaseComputer::Binary2Decimal(uf32);
|
||||
const auto index = RyuPrinter::PrintBase10Float(v, sign, result);
|
||||
return index;
|
||||
}
|
||||
|
||||
|
||||
// ====================== Integer ==================
|
||||
|
||||
// This is an implementation for base 10 inspired by the one in libstdc++v3. The general
|
||||
// scheme is by decomposing the value into multiple combination of base (which is 10) by
|
||||
// mod, until the value is lesser than 10, then last char is just char '0' (ascii 48) plus
|
||||
// that value. Other popular implementations can be found in RapidJson and libc++ (in
|
||||
// llvm-project), which uses the same general work flow with the same look up table, but
|
||||
// probably with better performance as they are more complicated.
|
||||
void ItoaUnsignedImpl(char *first, uint32_t length, uint64_t value) {
|
||||
uint32_t position = length - 1;
|
||||
while (value >= Tens(2)) {
|
||||
auto const num = (value % Tens(2)) * 2;
|
||||
value /= Tens(2);
|
||||
first[position] = kItoaLut[num + 1];
|
||||
first[position - 1] = kItoaLut[num];
|
||||
position -= 2;
|
||||
}
|
||||
if (value >= 10) {
|
||||
auto const num = value * 2;
|
||||
first[0] = kItoaLut[num];
|
||||
first[1] = kItoaLut[num + 1];
|
||||
} else {
|
||||
first[0]= '0' + value;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr uint32_t ShortestDigit10Impl(uint64_t value, uint32_t n) {
|
||||
// Should trigger tail recursion optimization.
|
||||
return value < 10 ? n :
|
||||
(value < Tens(2) ? n + 1 :
|
||||
(value < Tens(3) ? n + 2 :
|
||||
(value < Tens(4) ? n + 3 :
|
||||
ShortestDigit10Impl(value / Tens(4), n + 4))));
|
||||
}
|
||||
|
||||
constexpr uint32_t ShortestDigit10(uint64_t value) {
|
||||
return ShortestDigit10Impl(value, 1);
|
||||
}
|
||||
|
||||
to_chars_result ToCharsUnsignedImpl(char *first, char *last,
|
||||
uint64_t const value) {
|
||||
const uint32_t output_len = ShortestDigit10(value);
|
||||
to_chars_result ret;
|
||||
if (XGBOOST_EXPECT(std::distance(first, last) == 0, false)) {
|
||||
ret.ec = std::errc::value_too_large;
|
||||
ret.ptr = last;
|
||||
return ret;
|
||||
}
|
||||
|
||||
ItoaUnsignedImpl(first, output_len, value);
|
||||
ret.ptr = first + output_len;
|
||||
ret.ec = std::errc();
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* The parsing is also part of ryu. As of writing, the implementation in ryu uses full
|
||||
* double table. But here we optimize the table size with float table instead. The
|
||||
* result is exactly the same.
|
||||
*/
|
||||
from_chars_result FromCharFloatImpl(const char *buffer, const int len,
|
||||
float *result) {
|
||||
if (len == 0) {
|
||||
return {buffer, std::errc::invalid_argument};
|
||||
}
|
||||
int32_t m10digits = 0;
|
||||
int32_t e10digits = 0;
|
||||
int32_t dot_ind = len;
|
||||
int32_t e_ind = len;
|
||||
uint32_t mantissa_b10 = 0;
|
||||
int32_t exp_b10 = 0;
|
||||
bool signed_mantissa = false;
|
||||
bool signed_exp = false;
|
||||
int32_t i = 0;
|
||||
if (buffer[i] == '-') {
|
||||
signed_mantissa = true;
|
||||
i++;
|
||||
}
|
||||
for (; i < len; i++) {
|
||||
char c = buffer[i];
|
||||
if (c == '.') {
|
||||
if (dot_ind != len) {
|
||||
return {buffer + i, std::errc::invalid_argument};
|
||||
}
|
||||
dot_ind = i;
|
||||
continue;
|
||||
}
|
||||
if ((c < '0') || (c > '9')) {
|
||||
break;
|
||||
}
|
||||
if (m10digits >= 9) {
|
||||
return {buffer + i, std::errc::result_out_of_range};
|
||||
}
|
||||
mantissa_b10 = 10 * mantissa_b10 + (c - '0');
|
||||
if (mantissa_b10 != 0) {
|
||||
m10digits++;
|
||||
}
|
||||
}
|
||||
|
||||
if (i < len && ((buffer[i] == 'e') || (buffer[i] == 'E'))) {
|
||||
e_ind = i;
|
||||
i++;
|
||||
if (i < len && ((buffer[i] == '-') || (buffer[i] == '+'))) {
|
||||
signed_exp = buffer[i] == '-';
|
||||
i++;
|
||||
}
|
||||
for (; i < len; i++) {
|
||||
char c = buffer[i];
|
||||
if ((c < '0') || (c > '9')) {
|
||||
return {buffer + i, std::errc::invalid_argument};
|
||||
}
|
||||
if (e10digits > 3) {
|
||||
return {buffer + i, std::errc::result_out_of_range};
|
||||
}
|
||||
exp_b10 = 10 * exp_b10 + (c - '0');
|
||||
if (exp_b10 != 0) {
|
||||
e10digits++;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (i < len) {
|
||||
return {buffer + i, std::errc::invalid_argument};
|
||||
}
|
||||
if (signed_exp) {
|
||||
exp_b10 = -exp_b10;
|
||||
}
|
||||
exp_b10 -= dot_ind < e_ind ? e_ind - dot_ind - 1 : 0;
|
||||
if (mantissa_b10 == 0) {
|
||||
*result = signed_mantissa ? -0.0f : 0.0f;
|
||||
return {};
|
||||
}
|
||||
|
||||
if ((m10digits + exp_b10 <= -46) || (mantissa_b10 == 0)) {
|
||||
// Number is less than 1e-46, which should be rounded down to 0; return
|
||||
// +/-0.0.
|
||||
uint32_t ieee =
|
||||
(static_cast<uint32_t>(signed_mantissa))
|
||||
<< (IEEE754::kFloatExponentBits + IEEE754::kFloatMantissaBits);
|
||||
*result = BitCast<float>(ieee);
|
||||
return {};
|
||||
}
|
||||
if (m10digits + exp_b10 >= 40) {
|
||||
// Number is larger than 1e+39, which should be rounded to +/-Infinity.
|
||||
*result = IEEE754::Infinity(signed_mantissa);
|
||||
return {};
|
||||
}
|
||||
|
||||
// Convert to binary float m2 * 2^e2, while retaining information about
|
||||
// whether the conversion was exact (trailingZeros).
|
||||
int32_t exp_b2;
|
||||
uint32_t mantissa_b2;
|
||||
bool trailing_zeros;
|
||||
if (exp_b10 >= 0) {
|
||||
// The length of m * 10^e in bits is:
|
||||
// log2(m10 * 10^e10) = log2(m10) + e10 log2(10) = log2(m10) + e10 + e10 *
|
||||
// log2(5)
|
||||
//
|
||||
// We want to compute the IEEE754::kFloatMantissaBits + 1 top-most bits (+1 for the
|
||||
// implicit leading one in IEEE format). We therefore choose a binary output
|
||||
// exponent of
|
||||
// log2(m10 * 10^e10) - (IEEE754::kFloatMantissaBits + 1).
|
||||
//
|
||||
// We use floor(log2(5^e10)) so that we get at least this many bits; better
|
||||
// to have an additional bit than to not have enough bits.
|
||||
exp_b2 = RyuPowLogUtils::FloorLog2(mantissa_b10) + exp_b10 +
|
||||
RyuPowLogUtils::Log2Pow5(exp_b10) -
|
||||
(IEEE754::kFloatMantissaBits + 1);
|
||||
|
||||
// We now compute [m10 * 10^e10 / 2^e2] = [m10 * 5^e10 / 2^(e2-e10)].
|
||||
// To that end, we use the RyuPowLogUtils::kFloatPow5Bitcount table.
|
||||
int j = exp_b2 - exp_b10 - RyuPowLogUtils::CeilLog2Pow5(exp_b10) +
|
||||
RyuPowLogUtils::kFloatPow5Bitcount;
|
||||
assert(j >= 0);
|
||||
mantissa_b2 = RyuPowLogUtils::MulPow5divPow2(mantissa_b10, exp_b10, j);
|
||||
|
||||
// We also compute if the result is exact, i.e.,
|
||||
// [m10 * 10^e10 / 2^e2] == m10 * 10^e10 / 2^e2.
|
||||
// This can only be the case if 2^e2 divides m10 * 10^e10, which in turn
|
||||
// requires that the largest power of 2 that divides m10 + e10 is greater
|
||||
// than e2. If e2 is less than e10, then the result must be exact. Otherwise
|
||||
// we use the existing multipleOfPowerOf2 function.
|
||||
trailing_zeros =
|
||||
exp_b2 < exp_b10 ||
|
||||
(exp_b2 - exp_b10 < 32 &&
|
||||
RyuPowLogUtils::MultipleOfPowerOf2(mantissa_b10, exp_b2 - exp_b10));
|
||||
} else {
|
||||
exp_b2 = RyuPowLogUtils::FloorLog2(mantissa_b10) + exp_b10 -
|
||||
RyuPowLogUtils::CeilLog2Pow5(-exp_b10) -
|
||||
(IEEE754::kFloatMantissaBits + 1);
|
||||
|
||||
// We now compute [m10 * 10^e10 / 2^e2] = [m10 / (5^(-e10) 2^(e2-e10))].
|
||||
int j = exp_b2 - exp_b10 + RyuPowLogUtils::CeilLog2Pow5(-exp_b10) - 1 +
|
||||
RyuPowLogUtils::kFloatPow5InvBitcount;
|
||||
mantissa_b2 = RyuPowLogUtils::MulPow5InvDivPow2(mantissa_b10, -exp_b10, j);
|
||||
|
||||
// We also compute if the result is exact, i.e.,
|
||||
// [m10 / (5^(-e10) 2^(e2-e10))] == m10 / (5^(-e10) 2^(e2-e10))
|
||||
//
|
||||
// If e2-e10 >= 0, we need to check whether (5^(-e10) 2^(e2-e10)) divides
|
||||
// m10, which is the case iff pow5(m10) >= -e10 AND pow2(m10) >= e2-e10.
|
||||
//
|
||||
// If e2-e10 < 0, we have actually computed [m10 * 2^(e10 e2) / 5^(-e10)]
|
||||
// above, and we need to check whether 5^(-e10) divides (m10 * 2^(e10-e2)),
|
||||
// which is the case iff pow5(m10 * 2^(e10-e2)) = pow5(m10) >= -e10.
|
||||
trailing_zeros =
|
||||
(exp_b2 < exp_b10 ||
|
||||
(exp_b2 - exp_b10 < 32 && RyuPowLogUtils::MultipleOfPowerOf2(
|
||||
mantissa_b10, exp_b2 - exp_b10))) &&
|
||||
RyuPowLogUtils::MultipleOfPowerOf5(mantissa_b10, -exp_b10);
|
||||
}
|
||||
|
||||
// Compute the final IEEE exponent.
|
||||
uint32_t f_e2 =
|
||||
std::max(static_cast<int32_t>(0),
|
||||
static_cast<int32_t>(exp_b2 + IEEE754::kFloatBias +
|
||||
RyuPowLogUtils::FloorLog2(mantissa_b2)));
|
||||
|
||||
if (f_e2 > 0xfe) {
|
||||
// Final IEEE exponent is larger than the maximum representable; return
|
||||
// +/-Infinity.
|
||||
*result = IEEE754::Infinity(signed_mantissa);
|
||||
return {};
|
||||
}
|
||||
|
||||
// We need to figure out how much we need to shift m2. The tricky part is that
|
||||
// we need to take the final IEEE exponent into account, so we need to reverse
|
||||
// the bias and also special-case the value 0.
|
||||
int32_t shift = (f_e2 == 0 ? 1 : f_e2) - exp_b2 - IEEE754::kFloatBias -
|
||||
IEEE754::kFloatMantissaBits;
|
||||
assert(shift >= 0);
|
||||
|
||||
// We need to round up if the exact value is more than 0.5 above the value we
|
||||
// computed. That's equivalent to checking if the last removed bit was 1 and
|
||||
// either the value was not just trailing zeros or the result would otherwise
|
||||
// be odd.
|
||||
//
|
||||
// We need to update trailingZeros given that we have the exact output
|
||||
// exponent ieee_e2 now.
|
||||
trailing_zeros &= (mantissa_b2 & ((1u << (shift - 1)) - 1)) == 0;
|
||||
uint32_t lastRemovedBit = (mantissa_b2 >> (shift - 1)) & 1;
|
||||
bool roundup = (lastRemovedBit != 0) &&
|
||||
(!trailing_zeros || (((mantissa_b2 >> shift) & 1) != 0));
|
||||
|
||||
uint32_t f_m2 = (mantissa_b2 >> shift) + roundup;
|
||||
assert(f_m2 <= (1u << (IEEE754::kFloatMantissaBits + 1)));
|
||||
f_m2 &= (1u << IEEE754::kFloatMantissaBits) - 1;
|
||||
if (f_m2 == 0 && roundup) {
|
||||
// Rounding up may overflow the mantissa.
|
||||
// In this case we move a trailing zero of the mantissa into the exponent.
|
||||
// Due to how the IEEE represents +/-Infinity, we don't need to check for
|
||||
// overflow here.
|
||||
f_e2++;
|
||||
}
|
||||
*result = IEEE754::Encode({f_m2, f_e2}, signed_mantissa);
|
||||
return {};
|
||||
}
|
||||
} // namespace detail
|
||||
} // namespace xgboost
|
||||
103
src/common/charconv.h
Normal file
103
src/common/charconv.h
Normal file
@ -0,0 +1,103 @@
|
||||
/*!
|
||||
* Copyright 2019 by XGBoost Contributors
|
||||
*
|
||||
* \brief Implement `std::to_chars` and `std::from_chars` for float. Only base 10 with
|
||||
* scientific format is supported. The implementation guarantees roundtrip
|
||||
* reproducibility.
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_CHARCONV_H_
|
||||
#define XGBOOST_COMMON_CHARCONV_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <system_error>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
struct to_chars_result { // NOLINT
|
||||
char* ptr;
|
||||
std::errc ec;
|
||||
};
|
||||
|
||||
struct from_chars_result { // NOLINT
|
||||
const char *ptr;
|
||||
std::errc ec;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
int32_t ToCharsFloatImpl(float f, char * const result);
|
||||
to_chars_result ToCharsUnsignedImpl(char *first, char *last,
|
||||
uint64_t const value);
|
||||
from_chars_result FromCharFloatImpl(const char *buffer, const int len,
|
||||
float *result);
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
struct NumericLimits;
|
||||
|
||||
template <> struct NumericLimits<float> {
|
||||
// Unlike std::numeric_limit<float>::max_digits10, which represents the **minimum**
|
||||
// length of base10 digits that are necessary to uniquely represent all distinct values.
|
||||
// This value is used to represent the maximum length. As sign bit occupies 1 character:
|
||||
// sign + len(str(2^24)) + decimal point + `E` + sign + len(str(2^8)) + '\0'
|
||||
static constexpr size_t kToCharsSize = 16;
|
||||
};
|
||||
|
||||
template <> struct NumericLimits<int64_t> {
|
||||
// From llvm libcxx: numeric_limits::digits10 returns value less on 1 than desired for
|
||||
// unsigned numbers. For example, for 1-byte unsigned value digits10 is 2 (999 can not
|
||||
// be represented), so we need +1 here.
|
||||
static constexpr size_t kToCharsSize =
|
||||
std::numeric_limits<int64_t>::digits10 +
|
||||
3; // +1 for minus, +1 for digits10, +1 for '\0' just to be safe.
|
||||
};
|
||||
|
||||
inline to_chars_result to_chars(char *first, char *last, float value) { // NOLINT
|
||||
if (XGBOOST_EXPECT(!(static_cast<size_t>(last - first) >=
|
||||
NumericLimits<float>::kToCharsSize),
|
||||
false)) {
|
||||
return {first, std::errc::value_too_large};
|
||||
}
|
||||
auto index = detail::ToCharsFloatImpl(value, first);
|
||||
to_chars_result ret;
|
||||
ret.ptr = first + index;
|
||||
|
||||
if (XGBOOST_EXPECT(ret.ptr < last, true)) {
|
||||
ret.ec = std::errc();
|
||||
} else {
|
||||
ret.ec = std::errc::value_too_large;
|
||||
ret.ptr = last;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline to_chars_result to_chars(char *first, char *last, int64_t value) { // NOLINT
|
||||
if (XGBOOST_EXPECT(first == last, false)) {
|
||||
return {first, std::errc::value_too_large};
|
||||
}
|
||||
// first write '-' and convert to unsigned, then write the rest.
|
||||
if (value == 0) {
|
||||
*first = '0';
|
||||
return {std::next(first), std::errc()};
|
||||
}
|
||||
uint64_t unsigned_value = value;
|
||||
if (value < 0) {
|
||||
*first = '-';
|
||||
std::advance(first, 1);
|
||||
unsigned_value = uint64_t(~value) + uint64_t(1);
|
||||
}
|
||||
return detail::ToCharsUnsignedImpl(first, last, unsigned_value);
|
||||
}
|
||||
|
||||
inline from_chars_result from_chars(const char *buffer, const char *end, // NOLINT
|
||||
float &value) { // NOLINT
|
||||
from_chars_result res =
|
||||
detail::FromCharFloatImpl(buffer, std::distance(buffer, end), &value);
|
||||
return res;
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_COMMON_CHARCONV_H_
|
||||
@ -1,12 +1,15 @@
|
||||
/*!
|
||||
* Copyright (c) by Contributors 2019
|
||||
* Copyright (c) by Contributors 2019-2020
|
||||
*/
|
||||
#include <cctype>
|
||||
#include <cstddef>
|
||||
#include <iterator>
|
||||
#include <locale>
|
||||
#include <sstream>
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
|
||||
#include "charconv.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/json.h"
|
||||
@ -19,56 +22,68 @@ void JsonWriter::Save(Json json) {
|
||||
}
|
||||
|
||||
void JsonWriter::Visit(JsonArray const* arr) {
|
||||
this->Write("[");
|
||||
stream_->emplace_back('[');
|
||||
auto const& vec = arr->GetArray();
|
||||
size_t size = vec.size();
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
auto const& value = vec[i];
|
||||
this->Save(value);
|
||||
if (i != size-1) { Write(","); }
|
||||
if (i != size - 1) {
|
||||
stream_->emplace_back(',');
|
||||
}
|
||||
this->Write("]");
|
||||
}
|
||||
stream_->emplace_back(']');
|
||||
}
|
||||
|
||||
void JsonWriter::Visit(JsonObject const* obj) {
|
||||
this->Write("{");
|
||||
this->BeginIndent();
|
||||
this->NewLine();
|
||||
|
||||
stream_->emplace_back('{');
|
||||
size_t i = 0;
|
||||
size_t size = obj->GetObject().size();
|
||||
|
||||
for (auto& value : obj->GetObject()) {
|
||||
this->Write("\"" + value.first + "\":");
|
||||
auto s = String{value.first};
|
||||
this->Visit(&s);
|
||||
stream_->emplace_back(':');
|
||||
this->Save(value.second);
|
||||
|
||||
if (i != size-1) {
|
||||
this->Write(",");
|
||||
this->NewLine();
|
||||
stream_->emplace_back(',');
|
||||
}
|
||||
i++;
|
||||
}
|
||||
this->EndIndent();
|
||||
this->NewLine();
|
||||
this->Write("}");
|
||||
|
||||
stream_->emplace_back('}');
|
||||
}
|
||||
|
||||
void JsonWriter::Visit(JsonNumber const* num) {
|
||||
convertor_ << num->GetNumber();
|
||||
auto const& str = convertor_.str();
|
||||
this->Write(StringView{str.c_str(), str.size()});
|
||||
convertor_.str("");
|
||||
char number[NumericLimits<float>::kToCharsSize];
|
||||
auto res = to_chars(number, number + sizeof(number), num->GetNumber());
|
||||
auto end = res.ptr;
|
||||
auto ori_size = stream_->size();
|
||||
stream_->resize(stream_->size() + end - number);
|
||||
std::memcpy(stream_->data() + ori_size, number, end - number);
|
||||
}
|
||||
|
||||
void JsonWriter::Visit(JsonInteger const* num) {
|
||||
convertor_ << num->GetInteger();
|
||||
auto const& str = convertor_.str();
|
||||
this->Write(StringView{str.c_str(), str.size()});
|
||||
convertor_.str("");
|
||||
char i2s_buffer_[NumericLimits<int64_t>::kToCharsSize];
|
||||
auto i = num->GetInteger();
|
||||
auto ret = to_chars(i2s_buffer_, i2s_buffer_ + NumericLimits<int64_t>::kToCharsSize, i);
|
||||
auto end = ret.ptr;
|
||||
CHECK(ret.ec == std::errc());
|
||||
auto digits = std::distance(i2s_buffer_, end);
|
||||
auto ori_size = stream_->size();
|
||||
stream_->resize(ori_size + digits);
|
||||
std::memcpy(stream_->data() + ori_size, i2s_buffer_, digits);
|
||||
}
|
||||
|
||||
void JsonWriter::Visit(JsonNull const* null) {
|
||||
this->Write("null");
|
||||
auto s = stream_->size();
|
||||
stream_->resize(s + 4);
|
||||
auto& buf = (*stream_);
|
||||
buf[s + 0] = 'n';
|
||||
buf[s + 1] = 'u';
|
||||
buf[s + 2] = 'l';
|
||||
buf[s + 3] = 'l';
|
||||
}
|
||||
|
||||
void JsonWriter::Visit(JsonString const* str) {
|
||||
@ -105,15 +120,30 @@ void JsonWriter::Visit(JsonString const* str) {
|
||||
}
|
||||
}
|
||||
buffer += '"';
|
||||
this->Write(buffer);
|
||||
|
||||
auto s = stream_->size();
|
||||
stream_->resize(s + buffer.size());
|
||||
std::memcpy(stream_->data() + s, buffer.data(), buffer.size());
|
||||
}
|
||||
|
||||
void JsonWriter::Visit(JsonBoolean const* boolean) {
|
||||
bool val = boolean->GetBoolean();
|
||||
auto s = stream_->size();
|
||||
if (val) {
|
||||
this->Write(u8"true");
|
||||
stream_->resize(s + 4);
|
||||
auto& buf = (*stream_);
|
||||
buf[s + 0] = 't';
|
||||
buf[s + 1] = 'r';
|
||||
buf[s + 2] = 'u';
|
||||
buf[s + 3] = 'e';
|
||||
} else {
|
||||
this->Write(u8"false");
|
||||
stream_->resize(s + 5);
|
||||
auto& buf = (*stream_);
|
||||
buf[s + 0] = 'f';
|
||||
buf[s + 1] = 'a';
|
||||
buf[s + 2] = 'l';
|
||||
buf[s + 3] = 's';
|
||||
buf[s + 4] = 'e';
|
||||
}
|
||||
}
|
||||
|
||||
@ -310,7 +340,7 @@ Value & JsonNull::operator=(Value const &rhs) {
|
||||
}
|
||||
|
||||
void JsonNull::Save(JsonWriter* writer) {
|
||||
writer->Write("null");
|
||||
writer->Visit(this);
|
||||
}
|
||||
|
||||
// Json Boolean
|
||||
@ -354,7 +384,7 @@ Json JsonReader::Parse() {
|
||||
} else if ( c == '[' ) {
|
||||
return ParseArray();
|
||||
} else if ( c == '-' || std::isdigit(c) ||
|
||||
c == 'N' ) {
|
||||
c == 'N' || c == 'I') {
|
||||
// For now we only accept `NaN`, not `nan` as the later violiates LR(1) with `null`.
|
||||
return ParseNumber();
|
||||
} else if ( c == '\"' ) {
|
||||
@ -379,9 +409,13 @@ void JsonReader::Error(std::string msg) const {
|
||||
// just copy it.
|
||||
std::istringstream str_s(raw_str_.substr(0, raw_str_.size()));
|
||||
|
||||
msg += ", around character: " + std::to_string(cursor_.Pos());
|
||||
msg += ", around character position: " + std::to_string(cursor_.Pos());
|
||||
msg += '\n';
|
||||
|
||||
if (cursor_.Pos() == 0) {
|
||||
LOG(FATAL) << msg << ", \"" << str_s.str() << " \"";
|
||||
}
|
||||
|
||||
constexpr size_t kExtend = 8;
|
||||
auto beg = static_cast<int64_t>(cursor_.Pos()) -
|
||||
static_cast<int64_t>(kExtend) < 0 ? 0 : cursor_.Pos() - kExtend;
|
||||
@ -413,11 +447,15 @@ void JsonReader::Error(std::string msg) const {
|
||||
LOG(FATAL) << msg;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool IsSpace(char c) { return c == ' ' || c == '\n' || c == '\r' || c == '\t'; }
|
||||
} // anonymous namespace
|
||||
|
||||
// Json class
|
||||
void JsonReader::SkipSpaces() {
|
||||
while (cursor_.Pos() < raw_str_.size()) {
|
||||
char c = raw_str_[cursor_.Pos()];
|
||||
if (std::isspace(c)) {
|
||||
if (IsSpace(c)) {
|
||||
cursor_.Forward();
|
||||
} else {
|
||||
break;
|
||||
@ -438,7 +476,7 @@ void ParseStr(std::string const& str) {
|
||||
}
|
||||
|
||||
Json JsonReader::ParseString() {
|
||||
char ch { GetChar('\"') }; // NOLINT
|
||||
char ch { GetConsecutiveChar('\"') }; // NOLINT
|
||||
std::ostringstream output;
|
||||
std::string str;
|
||||
while (true) {
|
||||
@ -483,14 +521,14 @@ Json JsonReader::ParseNull() {
|
||||
Json JsonReader::ParseArray() {
|
||||
std::vector<Json> data;
|
||||
|
||||
char ch { GetChar('[') }; // NOLINT
|
||||
char ch { GetConsecutiveChar('[') }; // NOLINT
|
||||
while (true) {
|
||||
if (PeekNextChar() == ']') {
|
||||
GetChar(']');
|
||||
GetConsecutiveChar(']');
|
||||
return Json(std::move(data));
|
||||
}
|
||||
auto obj = Parse();
|
||||
data.push_back(obj);
|
||||
data.emplace_back(obj);
|
||||
ch = GetNextNonSpaceChar();
|
||||
if (ch == ']') break;
|
||||
if (ch != ',') {
|
||||
@ -502,14 +540,14 @@ Json JsonReader::ParseArray() {
|
||||
}
|
||||
|
||||
Json JsonReader::ParseObject() {
|
||||
GetChar('{');
|
||||
GetConsecutiveChar('{');
|
||||
|
||||
std::map<std::string, Json> data;
|
||||
SkipSpaces();
|
||||
char ch = PeekNextChar();
|
||||
|
||||
if (ch == '}') {
|
||||
GetChar('}');
|
||||
GetConsecutiveChar('}');
|
||||
return Json(std::move(data));
|
||||
}
|
||||
|
||||
@ -550,118 +588,105 @@ Json JsonReader::ParseNumber() {
|
||||
char const* const beg = p; // keep track of current pointer
|
||||
|
||||
// TODO(trivialfis): Add back all the checks for number
|
||||
bool negative = false;
|
||||
if (XGBOOST_EXPECT(*p == 'N', false)) {
|
||||
GetChar('N');
|
||||
GetChar('a');
|
||||
GetChar('N');
|
||||
GetConsecutiveChar('N');
|
||||
GetConsecutiveChar('a');
|
||||
GetConsecutiveChar('N');
|
||||
return Json(static_cast<Number::Float>(std::numeric_limits<float>::quiet_NaN()));
|
||||
}
|
||||
|
||||
if ('-' == *p) {
|
||||
++p;
|
||||
bool negative = false;
|
||||
switch (*p) {
|
||||
case '-': {
|
||||
negative = true;
|
||||
++p;
|
||||
break;
|
||||
}
|
||||
case '+': {
|
||||
negative = false;
|
||||
++p;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (XGBOOST_EXPECT(*p == 'I', false)) {
|
||||
cursor_.Forward(std::distance(beg, p)); // +/-
|
||||
for (auto i : {'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'}) {
|
||||
GetConsecutiveChar(i);
|
||||
}
|
||||
auto f = std::numeric_limits<float>::infinity();
|
||||
if (negative) {
|
||||
f = -f;
|
||||
}
|
||||
return Json(static_cast<Number::Float>(f));
|
||||
}
|
||||
|
||||
bool is_float = false;
|
||||
|
||||
using ExpInt = std::remove_const<
|
||||
decltype(std::numeric_limits<Number::Float>::max_exponent)>::type;
|
||||
constexpr auto kExpMax = std::numeric_limits<ExpInt>::max();
|
||||
constexpr auto kExpMin = std::numeric_limits<ExpInt>::min();
|
||||
|
||||
JsonInteger::Int i = 0;
|
||||
double f = 0.0; // Use double to maintain accuracy
|
||||
int64_t i = 0;
|
||||
|
||||
if (*p == '0') {
|
||||
++p;
|
||||
i = 0;
|
||||
p++;
|
||||
}
|
||||
|
||||
while (XGBOOST_EXPECT(*p >= '0' && *p <= '9', true)) {
|
||||
i = i * 10 + (*p - '0');
|
||||
p++;
|
||||
}
|
||||
|
||||
if (*p == '.') {
|
||||
p++;
|
||||
is_float = true;
|
||||
|
||||
while (*p >= '0' && *p <= '9') {
|
||||
i = i * 10 + (*p - '0');
|
||||
p++;
|
||||
}
|
||||
}
|
||||
|
||||
if (*p == 'E' || *p == 'e') {
|
||||
is_float = true;
|
||||
p++;
|
||||
|
||||
switch (*p) {
|
||||
case '-':
|
||||
case '+': {
|
||||
p++;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (XGBOOST_EXPECT(*p >= '0' && *p <= '9', true)) {
|
||||
p++;
|
||||
while (*p >= '0' && *p <= '9') {
|
||||
p++;
|
||||
}
|
||||
} else {
|
||||
char c = *p;
|
||||
do {
|
||||
++p;
|
||||
char digit = c - '0';
|
||||
i = 10 * i + digit;
|
||||
c = *p;
|
||||
} while (std::isdigit(c));
|
||||
Error("Expecting digit");
|
||||
}
|
||||
|
||||
ExpInt exponent = 0;
|
||||
const char *const dot_position = p;
|
||||
if ('.' == *p) {
|
||||
is_float = true;
|
||||
f = i;
|
||||
++p;
|
||||
char c = *p;
|
||||
|
||||
do {
|
||||
++p;
|
||||
f = f * 10 + (c - '0');
|
||||
c = *p;
|
||||
} while (std::isdigit(c));
|
||||
}
|
||||
if (is_float) {
|
||||
exponent = dot_position - p + 1;
|
||||
}
|
||||
|
||||
char e = *p;
|
||||
if ('e' == e || 'E' == e) {
|
||||
if (!is_float) {
|
||||
is_float = true;
|
||||
f = i;
|
||||
}
|
||||
++p;
|
||||
|
||||
bool negative_exponent = false;
|
||||
if ('-' == *p) {
|
||||
negative_exponent = true;
|
||||
++p;
|
||||
} else if ('+' == *p) {
|
||||
++p;
|
||||
}
|
||||
|
||||
ExpInt exp = 0;
|
||||
|
||||
char c = *p;
|
||||
while (std::isdigit(c)) {
|
||||
unsigned char digit = c - '0';
|
||||
if (XGBOOST_EXPECT(exp > (kExpMax - digit) / 10, false)) {
|
||||
CHECK_GT(exp, (kExpMax - digit) / 10) << "Overflow";
|
||||
}
|
||||
exp = 10 * exp + digit;
|
||||
++p;
|
||||
c = *p;
|
||||
}
|
||||
static_assert(-kExpMax >= kExpMin, "exp can be negated without loss or UB");
|
||||
exponent += (negative_exponent ? -exp : exp);
|
||||
}
|
||||
|
||||
if (exponent) {
|
||||
CHECK(is_float);
|
||||
// If d is zero but the exponent is huge, don't
|
||||
// multiply zero by inf which gives nan.
|
||||
if (f != 0.0) {
|
||||
// Only use exp10 from libc on gcc+linux
|
||||
#if !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) || !defined(__linux__)
|
||||
#define exp10(val) std::pow(10, (val))
|
||||
#endif // !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) || !defined(__linux__)
|
||||
f *= exp10(exponent);
|
||||
#if !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) || !defined(__linux__)
|
||||
#undef exp10
|
||||
#endif // !defined(__GNUC__) || defined(_WIN32) || defined(__APPLE__) || !defined(__linux__)
|
||||
}
|
||||
}
|
||||
|
||||
if (negative) {
|
||||
f = -f;
|
||||
i = -i;
|
||||
}
|
||||
|
||||
auto moved = std::distance(beg, p);
|
||||
this->cursor_.Forward(moved);
|
||||
|
||||
if (is_float) {
|
||||
float f;
|
||||
auto ret = from_chars(beg, p, f);
|
||||
if (XGBOOST_EXPECT(ret.ec != std::errc(), false)) {
|
||||
// Compatible with old format that generates very long mantissa from std stream.
|
||||
f = std::strtof(beg, nullptr);
|
||||
}
|
||||
return Json(static_cast<Number::Float>(f));
|
||||
} else {
|
||||
if (negative) {
|
||||
i = -i;
|
||||
}
|
||||
return Json(JsonInteger(i));
|
||||
}
|
||||
}
|
||||
@ -674,20 +699,15 @@ Json JsonReader::ParseBoolean() {
|
||||
std::string buffer;
|
||||
|
||||
if (ch == 't') {
|
||||
for (size_t i = 0; i < 3; ++i) {
|
||||
buffer.push_back(GetNextNonSpaceChar());
|
||||
}
|
||||
if (buffer != u8"rue") {
|
||||
Error("Expecting boolean value \"true\".");
|
||||
}
|
||||
GetConsecutiveChar('r');
|
||||
GetConsecutiveChar('u');
|
||||
GetConsecutiveChar('e');
|
||||
result = true;
|
||||
} else {
|
||||
for (size_t i = 0; i < 4; ++i) {
|
||||
buffer.push_back(GetNextNonSpaceChar());
|
||||
}
|
||||
if (buffer != u8"alse") {
|
||||
Error("Expecting boolean value \"false\".");
|
||||
}
|
||||
GetConsecutiveChar('a');
|
||||
GetConsecutiveChar('l');
|
||||
GetConsecutiveChar('s');
|
||||
GetConsecutiveChar('e');
|
||||
result = false;
|
||||
}
|
||||
return Json{JsonBoolean{result}};
|
||||
@ -704,16 +724,12 @@ Json Json::Load(JsonReader* reader) {
|
||||
return json;
|
||||
}
|
||||
|
||||
void Json::Dump(Json json, std::ostream *stream, bool pretty) {
|
||||
JsonWriter writer(stream, pretty);
|
||||
void Json::Dump(Json json, std::string* str) {
|
||||
std::vector<char> buffer;
|
||||
JsonWriter writer(&buffer);
|
||||
writer.Save(json);
|
||||
}
|
||||
|
||||
void Json::Dump(Json json, std::string* str, bool pretty) {
|
||||
std::stringstream ss;
|
||||
JsonWriter writer(&ss, pretty);
|
||||
writer.Save(json);
|
||||
*str = ss.str();
|
||||
str->resize(buffer.size());
|
||||
std::copy(buffer.cbegin(), buffer.cend(), str->begin());
|
||||
}
|
||||
|
||||
Json& Json::operator=(Json const &other) = default;
|
||||
|
||||
@ -53,7 +53,7 @@ class TrainingObserver {
|
||||
Json j_tree {Object()};
|
||||
tree.SaveModel(&j_tree);
|
||||
std::string str;
|
||||
Json::Dump(j_tree, &str, true);
|
||||
Json::Dump(j_tree, &str);
|
||||
OBSERVER_PRINT << str << OBSERVER_ENDL;
|
||||
}
|
||||
/*\brief Observe tree. */
|
||||
|
||||
@ -61,9 +61,8 @@ std::vector<Monitor::StatMap> Monitor::CollectFromOtherRanks() const {
|
||||
kv.second.timer.elapsed).count()));
|
||||
}
|
||||
|
||||
std::stringstream ss;
|
||||
Json::Dump(j_statistic, &ss);
|
||||
std::string const str { ss.str() };
|
||||
std::string str;
|
||||
Json::Dump(j_statistic, &str);
|
||||
|
||||
size_t str_size = str.size();
|
||||
rabit::Allreduce<rabit::op::Max>(&str_size, 1);
|
||||
|
||||
213
tests/cpp/common/test_charconv.cc
Normal file
213
tests/cpp/common/test_charconv.cc
Normal file
@ -0,0 +1,213 @@
|
||||
/*
|
||||
* The code is adopted from original (half) c implementation:
|
||||
* https://github.com/ulfjack/ryu.git with some more comments and tidying. License is
|
||||
* attached below.
|
||||
*
|
||||
* Copyright 2018 Ulf Adams
|
||||
*
|
||||
* The contents of this file may be used under the terms of the Apache License,
|
||||
* Version 2.0.
|
||||
*
|
||||
* (See accompanying file LICENSE-Apache or copy at
|
||||
* http: *www.apache.org/licenses/LICENSE-2.0)
|
||||
*
|
||||
* Alternatively, the contents of this file may be used under the terms of
|
||||
* the Boost Software License, Version 1.0.
|
||||
* (See accompanying file LICENSE-Boost or copy at
|
||||
* https://www.boost.org/LICENSE_1_0.txt)
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, this software
|
||||
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
* KIND, either express or implied.
|
||||
*/
|
||||
#include <cstddef>
|
||||
#include <gtest/gtest.h>
|
||||
#include <limits>
|
||||
#include "../../../src/common/charconv.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace {
|
||||
void TestInteger(char const* res, int64_t i) {
|
||||
char result[xgboost::NumericLimits<int64_t>::kToCharsSize];
|
||||
auto ret = to_chars(result, result + sizeof(result), i);
|
||||
*ret.ptr = '\0';
|
||||
EXPECT_STREQ(res, result);
|
||||
}
|
||||
|
||||
static float Int32Bits2Float(uint32_t bits) {
|
||||
float f;
|
||||
memcpy(&f, &bits, sizeof(float));
|
||||
return f;
|
||||
}
|
||||
|
||||
void TestRyu(char const *res, float v) {
|
||||
char result[xgboost::NumericLimits<float>::kToCharsSize];
|
||||
auto ret = to_chars(result, result + sizeof(result), v);
|
||||
*ret.ptr = '\0';
|
||||
EXPECT_STREQ(res, result);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Ryu, Subnormal) {
|
||||
TestRyu("0E0", 0.0f);
|
||||
TestRyu("-0E0", -0.0f);
|
||||
TestRyu("1E0", 1.0f);
|
||||
TestRyu("-1E0", -1.0f);
|
||||
TestRyu("NaN", NAN);
|
||||
TestRyu("Infinity", INFINITY);
|
||||
TestRyu("-Infinity", -INFINITY);
|
||||
|
||||
TestRyu("1E-45", std::numeric_limits<float>::denorm_min());
|
||||
}
|
||||
|
||||
TEST(Ryu, Denormal) {
|
||||
TestRyu("1E-45", std::numeric_limits<float>::denorm_min());
|
||||
}
|
||||
|
||||
TEST(Ryu, SwitchToSubnormal) {
|
||||
TestRyu("1.1754944E-38", 1.1754944E-38f);
|
||||
}
|
||||
|
||||
TEST(Ryu, MinAndMax) {
|
||||
TestRyu("3.4028235E38", Int32Bits2Float(0x7f7fffff));
|
||||
TestRyu("1E-45", Int32Bits2Float(1));
|
||||
}
|
||||
|
||||
// Check that we return the exact boundary if it is the shortest
|
||||
// representation, but only if the original floating point number is even.
|
||||
TEST(Ryu, BoundaryRoundEven) {
|
||||
TestRyu("3.355445E7", 3.355445E7f);
|
||||
TestRyu("9E9", 8.999999E9f);
|
||||
TestRyu("3.436672E10", 3.4366717E10f);
|
||||
}
|
||||
|
||||
// If the exact value is exactly halfway between two shortest representations,
|
||||
// then we round to even. It seems like this only makes a difference if the
|
||||
// last two digits are ...2|5 or ...7|5, and we cut off the 5.
|
||||
TEST(Ryu, ExactValueRoundEven) {
|
||||
TestRyu("3.0540412E5", 3.0540412E5f);
|
||||
TestRyu("8.0990312E3", 8.0990312E3f);
|
||||
}
|
||||
|
||||
TEST(Ryu, LotsOfTrailingZeros) {
|
||||
// Pattern for the first test: 00111001100000000000000000000000
|
||||
TestRyu("2.4414062E-4", 2.4414062E-4f);
|
||||
TestRyu("2.4414062E-3", 2.4414062E-3f);
|
||||
TestRyu("4.3945312E-3", 4.3945312E-3f);
|
||||
TestRyu("6.3476562E-3", 6.3476562E-3f);
|
||||
}
|
||||
|
||||
TEST(Ryu, Regression) {
|
||||
TestRyu("4.7223665E21", 4.7223665E21f);
|
||||
TestRyu("8.388608E6", 8388608.0f);
|
||||
TestRyu("1.6777216E7", 1.6777216E7f);
|
||||
TestRyu("3.3554436E7", 3.3554436E7f);
|
||||
TestRyu("6.7131496E7", 6.7131496E7f);
|
||||
TestRyu("1.9310392E-38", 1.9310392E-38f);
|
||||
TestRyu("-2.47E-43", -2.47E-43f);
|
||||
TestRyu("1.993244E-38", 1.993244E-38f);
|
||||
TestRyu("4.1039004E3", 4103.9003f);
|
||||
TestRyu("5.3399997E9", 5.3399997E9f);
|
||||
TestRyu("6.0898E-39", 6.0898E-39f);
|
||||
TestRyu("1.0310042E-3", 0.0010310042f);
|
||||
TestRyu("2.882326E17", 2.8823261E17f);
|
||||
TestRyu("7.038531E-26", 7.0385309E-26f);
|
||||
TestRyu("9.223404E17", 9.2234038E17f);
|
||||
TestRyu("6.710887E7", 6.7108872E7f);
|
||||
TestRyu("1E-44", 1.0E-44f);
|
||||
TestRyu("2.816025E14", 2.816025E14f);
|
||||
TestRyu("9.223372E18", 9.223372E18f);
|
||||
TestRyu("1.5846086E29", 1.5846085E29f);
|
||||
TestRyu("1.1811161E19", 1.1811161E19f);
|
||||
TestRyu("5.368709E18", 5.368709E18f);
|
||||
TestRyu("4.6143166E18", 4.6143165E18f);
|
||||
TestRyu("7.812537E-3", 0.007812537f);
|
||||
TestRyu("1E-45", 1.4E-45f);
|
||||
TestRyu("1.18697725E20", 1.18697724E20f);
|
||||
TestRyu("1.00014165E-36", 1.00014165E-36f);
|
||||
TestRyu("2E2", 200.0f);
|
||||
TestRyu("3.3554432E7", 3.3554432E7f);
|
||||
|
||||
static_assert(1.1920929E-7f == std::numeric_limits<float>::epsilon(), "");
|
||||
TestRyu("1.1920929E-7", std::numeric_limits<float>::epsilon());
|
||||
}
|
||||
|
||||
TEST(Ryu, RoundTrip) {
|
||||
float f = -1.1493590134238582e-40;
|
||||
char result[NumericLimits<float>::kToCharsSize] { 0 };
|
||||
auto ret = to_chars(result, result + sizeof(result), f);
|
||||
size_t dis = std::distance(result, ret.ptr);
|
||||
float back;
|
||||
auto from_ret = from_chars(result, result + dis, back);
|
||||
ASSERT_EQ(from_ret.ec, std::errc());
|
||||
std::string str;
|
||||
for (size_t i = 0; i < dis; ++i) {
|
||||
str.push_back(result[i]);
|
||||
}
|
||||
ASSERT_EQ(f, back);
|
||||
}
|
||||
|
||||
TEST(Ryu, LooksLikePow5) {
|
||||
// These numbers have a mantissa that is the largest power of 5 that fits,
|
||||
// and an exponent that causes the computation for q to result in 10, which is a corner
|
||||
// case for Ryu.
|
||||
TestRyu("6.7108864E17", Int32Bits2Float(0x5D1502F9));
|
||||
TestRyu("1.3421773E18", Int32Bits2Float(0x5D9502F9));
|
||||
TestRyu("2.6843546E18", Int32Bits2Float(0x5E1502F9));
|
||||
}
|
||||
|
||||
TEST(Ryu, OutputLength) {
|
||||
TestRyu("1E0", 1.0f); // already tested in Basic
|
||||
TestRyu("1.2E0", 1.2f);
|
||||
TestRyu("1.23E0", 1.23f);
|
||||
TestRyu("1.234E0", 1.234f);
|
||||
TestRyu("1.2345E0", 1.2345f);
|
||||
TestRyu("1.23456E0", 1.23456f);
|
||||
TestRyu("1.234567E0", 1.234567f);
|
||||
TestRyu("1.2345678E0", 1.2345678f);
|
||||
TestRyu("1.23456735E-36", 1.23456735E-36f);
|
||||
}
|
||||
|
||||
TEST(IntegerPrinting, Basic) {
|
||||
TestInteger("0", 0);
|
||||
auto str = std::to_string(std::numeric_limits<int64_t>::min());
|
||||
TestInteger(str.c_str(), std::numeric_limits<int64_t>::min());
|
||||
str = std::to_string(std::numeric_limits<int64_t>::max());
|
||||
TestInteger(str.c_str(), std::numeric_limits<int64_t>::max());
|
||||
}
|
||||
|
||||
void TestRyuParse(float f, std::string in) {
|
||||
float res;
|
||||
auto ret = from_chars(in.c_str(), in.c_str() + in.size(), res);
|
||||
ASSERT_EQ(ret.ec, std::errc());
|
||||
ASSERT_EQ(f, res);
|
||||
}
|
||||
|
||||
TEST(Ryu, Basic) {
|
||||
TestRyuParse(0.0f, "0");
|
||||
TestRyuParse(-0.0f, "-0");
|
||||
TestRyuParse(1.0f, "1");
|
||||
TestRyuParse(-1.0f, "-1");
|
||||
TestRyuParse(123456792.0f, "123456789");
|
||||
TestRyuParse(299792448.0f, "299792458");
|
||||
}
|
||||
|
||||
TEST(Ryu, MinMax) {
|
||||
TestRyuParse(1e-45f, "1e-45");
|
||||
TestRyuParse(FLT_MIN, "1.1754944e-38");
|
||||
TestRyuParse(FLT_MAX, "3.4028235e+38");
|
||||
}
|
||||
|
||||
TEST(Ryu, MantissaRoundingOverflow) {
|
||||
TestRyuParse(1.0f, "0.999999999");
|
||||
TestRyuParse(INFINITY, "3.4028236e+38");
|
||||
TestRyuParse(1.1754944e-38f, "1.17549430e-38"); // FLT_MIN
|
||||
}
|
||||
|
||||
TEST(Ryu, TrailingZeros) {
|
||||
TestRyuParse(26843550.0f, "26843549.5");
|
||||
TestRyuParse(50000004.0f, "50000002.5");
|
||||
TestRyuParse(99999992.0f, "99999989.5");
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
@ -248,9 +248,9 @@ TEST(HistUtil, AdapterDeviceSketch) {
|
||||
thrust::device_vector< float> data(rows*cols);
|
||||
auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
|
||||
data = std::vector<float >{ 1.0,2.0,3.0,4.0,5.0 };
|
||||
std::stringstream ss;
|
||||
Json::Dump(json_array_interface, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(json_array_interface, &str);
|
||||
|
||||
data::CupyAdapter adapter(str);
|
||||
|
||||
auto device_cuts = AdapterDeviceSketch(&adapter, num_bins, missing);
|
||||
|
||||
@ -54,9 +54,8 @@ inline data::CupyAdapter AdapterFromData(const thrust::device_vector<float> &x,
|
||||
array_interface["data"] = j_data;
|
||||
array_interface["version"] = Integer(static_cast<Integer::Int>(1));
|
||||
array_interface["typestr"] = String("<f4");
|
||||
std::stringstream ss;
|
||||
Json::Dump(array_interface, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(array_interface, &str);
|
||||
return data::CupyAdapter(str);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -9,7 +9,9 @@
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/json_io.h"
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/common/io.h"
|
||||
#include "../../../src/common/charconv.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -146,22 +148,46 @@ TEST(Json, ParseNumber) {
|
||||
{
|
||||
std::string str = "31.8892";
|
||||
auto json = Json::Load(StringView{str.c_str(), str.size()});
|
||||
ASSERT_NEAR(get<JsonNumber>(json), 31.8892f, kRtEps);
|
||||
ASSERT_EQ(get<JsonNumber>(json), 31.8892f);
|
||||
}
|
||||
{
|
||||
std::string str = "-31.8892";
|
||||
auto json = Json::Load(StringView{str.c_str(), str.size()});
|
||||
ASSERT_NEAR(get<JsonNumber>(json), -31.8892f, kRtEps);
|
||||
ASSERT_EQ(get<JsonNumber>(json), -31.8892f);
|
||||
}
|
||||
{
|
||||
std::string str = "2e4";
|
||||
auto json = Json::Load(StringView{str.c_str(), str.size()});
|
||||
ASSERT_NEAR(get<JsonNumber>(json), 2e4f, kRtEps);
|
||||
ASSERT_EQ(get<JsonNumber>(json), 2e4f);
|
||||
}
|
||||
{
|
||||
std::string str = "2e-4";
|
||||
auto json = Json::Load(StringView{str.c_str(), str.size()});
|
||||
ASSERT_NEAR(get<JsonNumber>(json), 2e-4f, kRtEps);
|
||||
ASSERT_EQ(get<JsonNumber>(json), 2e-4f);
|
||||
}
|
||||
{
|
||||
std::string str = "-2e-4";
|
||||
auto json = Json::Load(StringView{str.c_str(), str.size()});
|
||||
ASSERT_EQ(get<JsonNumber>(json), -2e-4f);
|
||||
}
|
||||
{
|
||||
std::string str = "-0.0";
|
||||
auto json = Json::Load(StringView{str.c_str(), str.size()});
|
||||
ASSERT_TRUE(std::signbit(get<JsonNumber>(json)));
|
||||
ASSERT_EQ(get<JsonNumber>(json), -0);
|
||||
}
|
||||
{
|
||||
std::string str = "-5.37645816802978516e-01";
|
||||
auto json = Json::Load(StringView{str.c_str(), str.size()});
|
||||
ASSERT_TRUE(std::signbit(get<JsonNumber>(json)));
|
||||
// Larger than fast path limit.
|
||||
ASSERT_EQ(get<JsonNumber>(json), -5.37645816802978516e-01);
|
||||
}
|
||||
{
|
||||
std::string str = "9.86623668670654297e+00";
|
||||
auto json = Json::Load(StringView{str.c_str(), str.size()});
|
||||
ASSERT_FALSE(std::signbit(get<JsonNumber>(json)));
|
||||
ASSERT_EQ(get<JsonNumber>(json), 9.86623668670654297e+00);
|
||||
}
|
||||
}
|
||||
|
||||
@ -200,13 +226,30 @@ TEST(Json, ParseArray) {
|
||||
Json v0 = arr[0];
|
||||
ASSERT_EQ(get<Integer>(v0["depth"]), 3);
|
||||
ASSERT_NEAR(get<Number>(v0["gain"]), 10.4866, kRtEps);
|
||||
|
||||
{
|
||||
std::string str =
|
||||
"[5.04713470458984375e+02,9.86623668670654297e+00,4.94847229003906250e+"
|
||||
"02,2.13924217224121094e+00,7.72699451446533203e+00,2."
|
||||
"30380615234375000e+02,2.64466613769531250e+02]";
|
||||
auto json = Json::Load(StringView{str.c_str(), str.size()});
|
||||
|
||||
auto const& vec = get<Array const>(json);
|
||||
ASSERT_EQ(get<Number const>(vec[0]), 5.04713470458984375e+02);
|
||||
ASSERT_EQ(get<Number const>(vec[1]), 9.86623668670654297e+00);
|
||||
ASSERT_EQ(get<Number const>(vec[2]), 4.94847229003906250e+02);
|
||||
ASSERT_EQ(get<Number const>(vec[3]), 2.13924217224121094e+00);
|
||||
ASSERT_EQ(get<Number const>(vec[4]), 7.72699451446533203e+00);
|
||||
ASSERT_EQ(get<Number const>(vec[5]), 2.30380615234375000e+02);
|
||||
ASSERT_EQ(get<Number const>(vec[6]), 2.64466613769531250e+02);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Json, Null) {
|
||||
Json json {JsonNull()};
|
||||
std::stringstream ss;
|
||||
std::string ss;
|
||||
Json::Dump(json, &ss);
|
||||
ASSERT_EQ(ss.str(), "null");
|
||||
ASSERT_EQ(ss, "null");
|
||||
|
||||
std::string null_input {R"null({"key": null })null"};
|
||||
|
||||
@ -288,7 +331,7 @@ TEST(Json, AssigningObjects) {
|
||||
Json json_object { JsonObject() };
|
||||
auto str = JsonString("1");
|
||||
auto& k = json_object["1"];
|
||||
k = str;
|
||||
k = std::move(str);
|
||||
auto& m = json_object["1"];
|
||||
std::string value = get<JsonString>(m);
|
||||
ASSERT_EQ(value, "1");
|
||||
@ -365,15 +408,56 @@ TEST(Json, LoadDump) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
auto const& path = tempdir.path + "test_model_dump";
|
||||
|
||||
std::ofstream fout (path);
|
||||
Json::Dump(origin, &fout);
|
||||
fout.close();
|
||||
std::string out;
|
||||
Json::Dump(origin, &out);
|
||||
|
||||
std::ofstream fout(path);
|
||||
ASSERT_TRUE(fout);
|
||||
fout << out << std::flush;
|
||||
|
||||
std::string new_buffer = common::LoadSequentialFile(path);
|
||||
Json load_back {Json::Load(StringView(new_buffer.c_str(), new_buffer.size()))};
|
||||
|
||||
ASSERT_EQ(load_back, origin) << ori_buffer << "\n\n---------------\n\n"
|
||||
<< new_buffer;
|
||||
Json load_back {Json::Load(StringView(new_buffer.c_str(), new_buffer.size()))};
|
||||
ASSERT_EQ(load_back, origin);
|
||||
}
|
||||
|
||||
TEST(Json, Invalid) {
|
||||
{
|
||||
std::string str = "}";
|
||||
bool has_thrown = false;
|
||||
try {
|
||||
Json load{Json::Load(StringView(str.c_str(), str.size()))};
|
||||
} catch (dmlc::Error const &e) {
|
||||
std::string msg = e.what();
|
||||
ASSERT_NE(msg.find("Unknown"), std::string::npos);
|
||||
has_thrown = true;
|
||||
};
|
||||
ASSERT_TRUE(has_thrown);
|
||||
}
|
||||
{
|
||||
std::string str = R"json({foo)json";
|
||||
bool has_thrown = false;
|
||||
try {
|
||||
Json load{Json::Load(StringView(str.c_str(), str.size()))};
|
||||
} catch (dmlc::Error const &e) {
|
||||
std::string msg = e.what();
|
||||
ASSERT_NE(msg.find("position: 1"), std::string::npos);
|
||||
has_thrown = true;
|
||||
};
|
||||
ASSERT_TRUE(has_thrown);
|
||||
}
|
||||
{
|
||||
std::string str = R"json({"foo")json";
|
||||
bool has_thrown = false;
|
||||
try {
|
||||
Json load{Json::Load(StringView(str.c_str(), str.size()))};
|
||||
} catch (dmlc::Error const &e) {
|
||||
std::string msg = e.what();
|
||||
ASSERT_NE(msg.find("EOF"), std::string::npos);
|
||||
has_thrown = true;
|
||||
};
|
||||
ASSERT_TRUE(has_thrown);
|
||||
}
|
||||
}
|
||||
|
||||
// For now Json is quite ignorance about unicode.
|
||||
@ -383,10 +467,9 @@ TEST(Json, CopyUnicode) {
|
||||
)json";
|
||||
Json loaded {Json::Load(StringView{json_str.c_str(), json_str.size()})};
|
||||
|
||||
std::stringstream ss_1;
|
||||
Json::Dump(loaded, &ss_1);
|
||||
std::string dumped_string;
|
||||
Json::Dump(loaded, &dumped_string);
|
||||
|
||||
std::string dumped_string = ss_1.str();
|
||||
ASSERT_NE(dumped_string.find("\\u20ac"), std::string::npos);
|
||||
}
|
||||
|
||||
@ -406,6 +489,15 @@ TEST(Json, WrongCasts) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Json, Integer) {
|
||||
for (int64_t i = 1; i < 10000; i *= 10) {
|
||||
auto ten = Json{Integer{i}};
|
||||
std::string str;
|
||||
Json::Dump(ten, &str);
|
||||
ASSERT_EQ(str, std::to_string(i));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Json, IntVSFloat) {
|
||||
// If integer is parsed as float, calling `get<Integer>()' will throw.
|
||||
{
|
||||
@ -432,4 +524,31 @@ TEST(Json, IntVSFloat) {
|
||||
ASSERT_EQ(ptr, 2503595760);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Json, RoundTrip) {
|
||||
uint32_t i = 0;
|
||||
SimpleLCG rng;
|
||||
SimpleRealUniformDistribution<float> dist(1.0f, 4096.0f);
|
||||
|
||||
while (i <= std::numeric_limits<uint32_t>::max()) {
|
||||
float f;
|
||||
std::memcpy(&f, &i, sizeof(f));
|
||||
|
||||
Json jf { f };
|
||||
std::string str;
|
||||
Json::Dump(jf, &str);
|
||||
auto loaded = Json::Load({str.c_str(), str.size()});
|
||||
if (XGBOOST_EXPECT(std::isnan(f), false)) {
|
||||
ASSERT_TRUE(std::isnan(get<Number const>(loaded)));
|
||||
} else {
|
||||
ASSERT_EQ(get<Number const>(loaded), f);
|
||||
}
|
||||
|
||||
auto t = i;
|
||||
i+= static_cast<uint32_t>(dist(&rng));
|
||||
if (i < t) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -50,7 +50,7 @@ __global__ void TestFromOtherKernel(Span<float> span) {
|
||||
}
|
||||
}
|
||||
// Test converting different T
|
||||
__global__ void TestFromOtherKernelConst(Span<float const, 16> span) {
|
||||
__global__ void TestFromOtherKernelConst(Span<float const, 16> span) {
|
||||
// don't get optimized out
|
||||
size_t idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
|
||||
@ -23,9 +23,8 @@ void TestCudfAdapter()
|
||||
|
||||
Json column_arr {columns};
|
||||
|
||||
std::stringstream ss;
|
||||
Json::Dump(column_arr, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(column_arr, &str);
|
||||
|
||||
data::CudfAdapter adapter(str);
|
||||
|
||||
|
||||
@ -78,9 +78,8 @@ TEST(DeviceDMatrix, ColumnMajor) {
|
||||
|
||||
Json column_arr{columns};
|
||||
|
||||
std::stringstream ss;
|
||||
Json::Dump(column_arr, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(column_arr, &str);
|
||||
|
||||
data::CudfAdapter adapter(str);
|
||||
data::DeviceDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
|
||||
|
||||
@ -32,9 +32,8 @@ std::string PrepareData(std::string typestr, thrust::device_vector<T>* out, cons
|
||||
column["data"] = j_data;
|
||||
Json array(std::vector<Json>{column});
|
||||
|
||||
std::stringstream ss;
|
||||
Json::Dump(array, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(array, &str);
|
||||
|
||||
return str;
|
||||
}
|
||||
|
||||
@ -22,9 +22,8 @@ TEST(SimpleDMatrix, FromColumnarDenseBasic) {
|
||||
|
||||
Json column_arr{columns};
|
||||
|
||||
std::stringstream ss;
|
||||
Json::Dump(column_arr, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(column_arr, &str);
|
||||
|
||||
data::CudfAdapter adapter(str);
|
||||
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
|
||||
@ -59,9 +58,8 @@ TEST(SimpleDMatrix, FromColumnarDense) {
|
||||
|
||||
Json column_arr{columns};
|
||||
|
||||
std::stringstream ss;
|
||||
Json::Dump(column_arr, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(column_arr, &str);
|
||||
|
||||
// no missing value
|
||||
{
|
||||
@ -156,9 +154,9 @@ TEST(SimpleDMatrix, FromColumnarWithEmptyRows) {
|
||||
}
|
||||
|
||||
Json column_arr{Array(v_columns)};
|
||||
std::stringstream ss;
|
||||
Json::Dump(column_arr, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(column_arr, &str);
|
||||
|
||||
data::CudfAdapter adapter(str);
|
||||
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
|
||||
-1);
|
||||
@ -244,9 +242,8 @@ TEST(SimpleCSRSource, FromColumnarSparse) {
|
||||
|
||||
Json column_arr {Array(j_columns)};
|
||||
|
||||
std::stringstream ss;
|
||||
Json::Dump(column_arr, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(column_arr, &str);
|
||||
|
||||
{
|
||||
data::CudfAdapter adapter(str);
|
||||
@ -296,9 +293,8 @@ TEST(SimpleDMatrix, FromColumnarSparseBasic) {
|
||||
|
||||
Json column_arr{columns};
|
||||
|
||||
std::stringstream ss;
|
||||
Json::Dump(column_arr, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(column_arr, &str);
|
||||
|
||||
data::CudfAdapter adapter(str);
|
||||
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
|
||||
@ -324,9 +320,8 @@ TEST(SimpleDMatrix, FromCupy){
|
||||
int cols = 10;
|
||||
thrust::device_vector< float> data(rows*cols);
|
||||
auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
|
||||
std::stringstream ss;
|
||||
Json::Dump(json_array_interface, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(json_array_interface, &str);
|
||||
data::CupyAdapter adapter(str);
|
||||
data::SimpleDMatrix dmat(&adapter, -1, 1);
|
||||
EXPECT_EQ(dmat.Info().num_col_, cols);
|
||||
@ -351,9 +346,8 @@ TEST(SimpleDMatrix, FromCupySparse){
|
||||
auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
|
||||
data[1] = std::numeric_limits<float>::quiet_NaN();
|
||||
data[2] = std::numeric_limits<float>::quiet_NaN();
|
||||
std::stringstream ss;
|
||||
Json::Dump(json_array_interface, &ss);
|
||||
std::string str = ss.str();
|
||||
std::string str;
|
||||
Json::Dump(json_array_interface, &str);
|
||||
data::CupyAdapter adapter(str);
|
||||
data::SimpleDMatrix dmat(&adapter, -1, 1);
|
||||
EXPECT_EQ(dmat.Info().num_col_, cols);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user