Add Json integer, remove specialization. (#4739)

This commit is contained in:
Jiaming Yuan
2019-08-06 03:10:49 -04:00
committed by GitHub
parent 9c469b3844
commit 2a4df8e29f
5 changed files with 314 additions and 221 deletions

View File

@@ -69,6 +69,12 @@
#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z))
#endif // GLIBC VERSION
/*!
 * \brief Branch hint macro.  When __GNUC__ is defined it expands to
 *  __builtin_expect(cond, ret), where `ret` is the value `cond` is expected to
 *  have.  Otherwise it expands to the plain condition and `ret` is ignored.
 */
#if defined(__GNUC__)
#define XGBOOST_EXPECT(cond, ret) __builtin_expect((cond), (ret))
#else
#define XGBOOST_EXPECT(cond, ret) (cond)
#endif // defined(__GNUC__)
/*!
* \brief Tag function as usable by device
*/

View File

@@ -4,8 +4,9 @@
#ifndef XGBOOST_JSON_H_
#define XGBOOST_JSON_H_
#include <xgboost/logging.h>
#include <dmlc/io.h>
#include <xgboost/logging.h>
#include <string>
#include <map>
@@ -29,7 +30,6 @@ class Value {
Integer,
Object, // std::map
Array, // std::vector
Raw,
Boolean,
Null
};
@@ -63,9 +63,9 @@ T* Cast(U* value) {
if (IsA<T>(value)) {
return dynamic_cast<T*>(value);
} else {
throw std::runtime_error(
"Invalid cast, from " + value->TypeStr() + " to " + T().TypeStr());
LOG(FATAL) << "Invalid cast, from " + value->TypeStr() + " to " + T().TypeStr();
}
return dynamic_cast<T*>(value); // supress compiler warning.
}
class JsonString : public Value {
@@ -123,32 +123,6 @@ class JsonArray : public Value {
}
};
/*!
 * \brief Json value of kind ValueKind::Raw that stores a std::string.
 */
class JsonRaw : public Value {
// The stored string.
std::string str_;
public:
// Construct from an rvalue string; the argument is moved into the stored string.
explicit JsonRaw(std::string&& str) :
Value(ValueKind::Raw),
str_{std::move(str)}{} // NOLINT
// Construct with an empty stored string.
JsonRaw() : Value(ValueKind::Raw) {}
// Accessors for the stored string.  The rvalue and const-lvalue overloads
// return a const reference; the non-const lvalue overload returns a mutable one.
std::string const& getRaw() && { return str_; }
std::string const& getRaw() const & { return str_; }
std::string& getRaw() & { return str_; }
// Overrides of the Value interface; only declared here, defined elsewhere.
void Save(JsonWriter* writer) override;
Json& operator[](std::string const & key) override;
Json& operator[](int ind) override;
bool operator==(Value const& rhs) const override;
Value& operator=(Value const& rhs) override;
// Type test on the kind tag reported by value->Type().
static bool isClassOf(Value const* value) {
return value->Type() == ValueKind::Raw;
}
};
class JsonObject : public Value {
std::map<std::string, Json> object_;
@@ -185,7 +159,9 @@ class JsonNumber : public Value {
public:
JsonNumber() : Value(ValueKind::Number) {}
JsonNumber(double value) : Value(ValueKind::Number) { // NOLINT
template <typename FloatT,
typename std::enable_if<std::is_same<FloatT, Float>::value>::type* = nullptr>
JsonNumber(FloatT value) : Value(ValueKind::Number) { // NOLINT
number_ = value;
}
@@ -198,6 +174,7 @@ class JsonNumber : public Value {
Float const& getNumber() const & { return number_; }
Float& getNumber() & { return number_; }
bool operator==(Value const& rhs) const override;
Value& operator=(Value const& rhs) override;
@@ -206,6 +183,35 @@ class JsonNumber : public Value {
}
};
/*!
 * \brief Json value of kind ValueKind::Integer that stores a 64-bit signed integer.
 */
class JsonInteger : public Value {
public:
// Underlying storage type for the integer.
using Int = int64_t;
private:
// The stored integer.
Int integer_;
public:
// Default construct with the stored integer set to 0.
JsonInteger() : Value(ValueKind::Integer), integer_{0} {} // NOLINT
// Converting constructor enabled only when the argument type is exactly Int,
// so other argument types do not select this overload.
// NOTE(review): presumably this avoids ambiguity with JsonNumber -- confirm.
template <typename IntT,
typename std::enable_if<std::is_same<IntT, Int>::value>::type* = nullptr>
JsonInteger(IntT value) : Value(ValueKind::Integer), integer_{value} {} // NOLINT
// Overrides of the Value interface; only declared here, defined elsewhere.
Json& operator[](std::string const & key) override;
Json& operator[](int ind) override;
bool operator==(Value const& rhs) const override;
Value& operator=(Value const& rhs) override;
// Accessors for the stored integer.  The rvalue and const-lvalue overloads
// return a const reference; the non-const lvalue overload returns a mutable one.
Int const& getInteger() && { return integer_; }
Int const& getInteger() const & { return integer_; }
Int& getInteger() & { return integer_; }
void Save(JsonWriter* writer) override;
// Type test on the kind tag reported by value->Type().
static bool isClassOf(Value const* value) {
return value->Type() == ValueKind::Integer;
}
};
class JsonNull : public Value {
public:
JsonNull() : Value(ValueKind::Null) {}
@@ -256,15 +262,16 @@ class JsonBoolean : public Value {
};
struct StringView {
char const* str_;
using CharT = char; // unsigned char
CharT const* str_;
size_t size_;
public:
StringView() = default;
StringView(char const* str, size_t size) : str_{str}, size_{size} {}
StringView(CharT const* str, size_t size) : str_{str}, size_{size} {}
char const& operator[](size_t p) const { return str_[p]; }
char const& at(size_t p) const { // NOLINT
CharT const& operator[](size_t p) const { return str_[p]; }
CharT const& at(size_t p) const { // NOLINT
CHECK_LT(p, size_);
return str_[p];
}
@@ -302,7 +309,7 @@ class Json {
public:
/*! \brief Load a Json object from string. */
static Json Load(StringView str, bool ignore_specialization = false);
static Json Load(StringView str);
/*! \brief Pass your own JsonReader. */
static Json Load(JsonReader* reader);
/*! \brief Dump json into stream. */
@@ -319,6 +326,13 @@ class Json {
return *this;
}
// integer
/*! \brief Construct a Json whose held value is a new JsonInteger built from `integer`. */
explicit Json(JsonInteger integer) : ptr_{new JsonInteger(std::move(integer))} {}
/*! \brief Replace the held value with a new JsonInteger built from `integer`. */
Json& operator=(JsonInteger integer) {
// The parameter is taken by value, so it can be moved into the new object,
// matching the constructor above.
ptr_.reset(new JsonInteger(std::move(integer)));
return *this;
}
// array
explicit Json(JsonArray list) :
ptr_ {new JsonArray(std::move(list))} {}
@@ -327,14 +341,6 @@ class Json {
return *this;
}
// raw
// Construct a Json whose held value is a new JsonRaw; the by-value argument is moved in.
explicit Json(JsonRaw str) :
ptr_{new JsonRaw(std::move(str))} {}
// Replace the held value with a new JsonRaw built from the by-value argument.
Json& operator=(JsonRaw str) {
ptr_.reset(new JsonRaw(std::move(str)));
return *this;
}
// object
explicit Json(JsonObject object) :
ptr_{new JsonObject(std::move(object))} {}
@@ -410,10 +416,24 @@ JsonNumber::Float& GetImpl(T& val) { // NOLINT
template <typename T,
typename std::enable_if<
std::is_same<T, JsonNumber const>::value>::type* = nullptr>
double const& GetImpl(T& val) { // NOLINT
JsonNumber::Float const& GetImpl(T& val) { // NOLINT
return val.getNumber();
}
// Integer
// Overload enabled only when T is exactly (non-const) JsonInteger; returns a
// mutable reference to the stored integer obtained from getInteger().
template <typename T,
typename std::enable_if<
std::is_same<T, JsonInteger>::value>::type* = nullptr>
JsonInteger::Int& GetImpl(T& val) { // NOLINT
return val.getInteger();
}
// Overload enabled only when T is exactly `JsonInteger const`; returns a const
// reference to the stored integer obtained from getInteger().
template <typename T,
typename std::enable_if<
std::is_same<T, JsonInteger const>::value>::type* = nullptr>
JsonInteger::Int const& GetImpl(T& val) { // NOLINT
return val.getInteger();
}
// String
template <typename T,
typename std::enable_if<
@@ -442,19 +462,6 @@ bool const& GetImpl(T& val) { // NOLINT
return val.getBoolean();
}
// Overload enabled only when T is exactly (non-const) JsonRaw; returns a mutable
// reference to the stored string obtained from getRaw().
template <typename T,
typename std::enable_if<
std::is_same<T, JsonRaw>::value>::type* = nullptr>
std::string& GetImpl(T& val) { // NOLINT
return val.getRaw();
}
// Overload enabled only when T is exactly `JsonRaw const`; returns a const
// reference to the stored string obtained from getRaw().
template <typename T,
typename std::enable_if<
std::is_same<T, JsonRaw const>::value>::type* = nullptr>
std::string const& GetImpl(T& val) { // NOLINT
return val.getRaw();
}
// Array
template <typename T,
typename std::enable_if<
@@ -502,10 +509,10 @@ auto get(U& json) -> decltype(detail::GetImpl(*Cast<T>(&json.GetValue())))& { //
// Short aliases for the Json value classes.
using Object = JsonObject;
using Array = JsonArray;
using Number = JsonNumber;
using Integer = JsonInteger;
using Boolean = JsonBoolean;
using String = JsonString;
using Null = JsonNull;
using Raw = JsonRaw;
// Utils tailored for XGBoost.
@@ -518,13 +525,14 @@ Object toJson(dmlc::Parameter<Type> const& param) {
return obj;
}
inline std::map<std::string, std::string> fromJson(std::map<std::string, Json> const& param) {
std::map<std::string, std::string> res;
for (auto const& kv : param) {
res[kv.first] = get<String const>(kv.second);
template <typename Type>
void fromJson(Json const& obj, dmlc::Parameter<Type>* param) {
auto const& j_param = get<Object const>(obj);
std::map<std::string, std::string> m;
for (auto const& kv : j_param) {
m[kv.first] = get<String const>(kv.second);
}
return res;
param->InitAllowUnknown(m);
}
} // namespace xgboost
#endif // XGBOOST_JSON_H_

View File

@@ -22,50 +22,15 @@ class FixedPrecisionStreamContainer : public std::basic_stringstream<
public:
// Configure the stream for writing floating point values.
FixedPrecisionStreamContainer() {
// Use max_digits10 significant digits of Number::Float.
this->precision(std::numeric_limits<Number::Float>::max_digits10);
// Use the "C" locale so formatting does not follow the global locale.
this->imbue(std::locale("C"));
// Set the scientific format flag.
this->setf(std::ios::scientific);
}
};
using FixedPrecisionStream = FixedPrecisionStreamContainer<std::allocator<char>>;
/*
* \brief A reader that can be specialised.
*
* Why specialization?
*
* First of all, we don't like specialization. This is purely a performance concern.
* A distributed environment frequently serializes the model, so at some point this could
* be a bottleneck for training performance. There are many other techniques for obtaining
* better performance, but all of them require implementing their own allocator(s),
* using SIMD instructions. And few of them can provide an easy-to-modify structure,
* since they assume a fixed memory layout.
*
* In XGBoost we provide specialized logic for parsing/writing tree models and linear
* models, where dense numeric values are present, including weights, node ids, etc.
*
* Plan for removing the specialization:
*
* We plan to upstream this implementation into DMLC as it matures. For XGBoost, most of
* the time spent in load/dump is actually `sprintf`.
*
* To enable specialization, register a keyword that corresponds to
* key in Json object. For example in:
*
* \code
* { "key": {...} }
* \endcode
*
* To add special logic for parsing {...}, one can call:
*
* \code
* JsonReader::registry("key", [](StringView str, size_t* pos){ ... return JsonRaw(...); });
* \endcode
*
* Where str is a view of the entire input string, while pos is a pointer to the current
* position. The function must return a raw object. Later, after obtaining a parsed object,
* say `Json obj`, you can obtain the raw object by calling `obj["key"]`, then perform the
* specialized parsing on it.
*
* See `LinearSelectRaw` and `LinearReader` in combination as an example.
* \brief A json reader. Currently error checking and utf-8 are not fully supported.
*/
class JsonReader {
protected:
@@ -77,17 +42,19 @@ class JsonReader {
public:
SourceLocation() : pos_(0) {}
explicit SourceLocation(size_t pos) : pos_{pos} {}
size_t Pos() const { return pos_; }
SourceLocation& Forward(char c = 0) {
SourceLocation& Forward() {
pos_++;
return *this;
}
// Move the cursor ahead by `n` positions and return this location for chaining.
SourceLocation& Forward(uint32_t n) {
pos_ = pos_ + n;
return *this;
}
} cursor_;
StringView raw_str_;
bool ignore_specialization_;
protected:
void SkipSpaces();
@@ -140,32 +107,13 @@ class JsonReader {
Json Parse();
private:
using Fn = std::function<Json (StringView, size_t*)>;
public:
explicit JsonReader(StringView str, bool ignore = false) :
raw_str_{str},
ignore_specialization_{ignore} {}
explicit JsonReader(StringView str, size_t pos, bool ignore = false) :
cursor_{pos},
raw_str_{str},
ignore_specialization_{ignore} {}
explicit JsonReader(StringView str) :
raw_str_{str} {}
virtual ~JsonReader() = default;
Json Load();
// Access the keyword -> parser function table.  The table is a function-local
// static, so every call returns a reference to the same map.
static std::map<std::string, Fn>& getRegistry() {
static std::map<std::string, Fn> parsers;
return parsers;
}
// Store `fn` under `key` in the table returned by getRegistry(), replacing any
// existing entry for that key, and return a const reference to the table.
static std::map<std::string, Fn> const& registry(
std::string const& key, Fn fn) {
auto& parsers = getRegistry();
parsers[key] = fn;
return parsers;
}
};
class JsonWriter {
@@ -207,7 +155,7 @@ class JsonWriter {
virtual void Visit(JsonArray const* arr);
virtual void Visit(JsonObject const* obj);
virtual void Visit(JsonNumber const* num);
virtual void Visit(JsonRaw const* raw);
virtual void Visit(JsonInteger const* num);
virtual void Visit(JsonNull const* null);
virtual void Visit(JsonString const* str);
virtual void Visit(JsonBoolean const* boolean);