Handle special characters in JSON model dump. (#9474)
This commit is contained in:
parent
f03463c45b
commit
05d7000096
@ -1,16 +1,17 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2015-2019 by Contributors
|
* Copyright 2015-2023 by Contributors
|
||||||
* \file common.cc
|
|
||||||
* \brief Enable all kinds of global variables in common.
|
|
||||||
*/
|
*/
|
||||||
#include <dmlc/thread_local.h>
|
|
||||||
#include <xgboost/logging.h>
|
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "./random.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
#include <dmlc/thread_local.h> // for ThreadLocalStore
|
||||||
namespace common {
|
|
||||||
|
#include <cstdint> // for uint8_t
|
||||||
|
#include <cstdio> // for snprintf, size_t
|
||||||
|
#include <string> // for string
|
||||||
|
|
||||||
|
#include "./random.h" // for GlobalRandomEngine, GlobalRandom
|
||||||
|
|
||||||
|
namespace xgboost::common {
|
||||||
/*! \brief thread local entry for random. */
|
/*! \brief thread local entry for random. */
|
||||||
struct RandomThreadLocalEntry {
|
struct RandomThreadLocalEntry {
|
||||||
/*! \brief the random engine instance. */
|
/*! \brief the random engine instance. */
|
||||||
@ -19,15 +20,43 @@ struct RandomThreadLocalEntry {
|
|||||||
|
|
||||||
using RandomThreadLocalStore = dmlc::ThreadLocalStore<RandomThreadLocalEntry>;
|
using RandomThreadLocalStore = dmlc::ThreadLocalStore<RandomThreadLocalEntry>;
|
||||||
|
|
||||||
GlobalRandomEngine& GlobalRandom() {
|
/*! \brief Return the random engine stored in the thread-local random entry. */
GlobalRandomEngine &GlobalRandom() {
  auto *entry = RandomThreadLocalStore::Get();
  return entry->engine;
}
|
||||||
return RandomThreadLocalStore::Get()->engine;
|
|
||||||
|
/*!
 * \brief Escape a string so it can be placed inside a JSON string literal.
 *
 * The escaped form of \p string is appended to \p p_buffer; the surrounding double
 * quotes are not added here. Backslash, double quote, backspace, form feed, new line,
 * carriage return and tab are written as two-character escapes. Any other character
 * with a value of at most 0x1f is written as a six-character hexadecimal unicode
 * escape. A backslash that is directly followed by 'u' is copied as a single
 * backslash, presumably so that an already escaped unicode sequence in the input is
 * not escaped a second time. Everything else is copied unchanged.
 *
 * \param string   Text to escape.
 * \param p_buffer Output buffer. It is appended to, existing content is kept.
 */
void EscapeU8(std::string const &string, std::string *p_buffer) {
  auto &buffer = *p_buffer;
  std::size_t const n_chars = string.size();
  for (std::size_t i = 0; i < n_chars; ++i) {
    char const ch = string[i];
    switch (ch) {
      case '\\': {
        // Look one character ahead, but only when that character exists.
        bool const is_unicode_escape = (i + 1 < n_chars) && string[i + 1] == 'u';
        buffer += is_unicode_escape ? "\\" : "\\\\";
        break;
      }
      case '"':
        buffer += "\\\"";
        break;
      case '\b':
        buffer += "\\b";
        break;
      case '\f':
        buffer += "\\f";
        break;
      case '\n':
        buffer += "\\n";
        break;
      case '\r':
        buffer += "\\r";
        break;
      case '\t':
        buffer += "\\t";
        break;
      default: {
        auto const code = static_cast<std::uint8_t>(ch);
        if (code <= 0x1f) {
          // Remaining control characters (up to the unit separator). The output is a
          // backslash, 'u' and four hex digits, which fits in 8 bytes including the
          // terminator. %x expects an unsigned int, hence the explicit cast.
          char buf[8];
          std::snprintf(buf, sizeof buf, "\\u%04x", static_cast<unsigned>(code));
          buffer += buf;
        } else {
          buffer += ch;
        }
        break;
      }
    }
  }
}
|
||||||
|
|
||||||
#if !defined(XGBOOST_USE_CUDA)
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
int AllVisibleGPUs() {
|
/*! \brief Number of visible GPUs; this definition always reports zero. */
int AllVisibleGPUs() {
  constexpr int kNoDevice = 0;
  return kNoDevice;
}
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
#endif // !defined(XGBOOST_USE_CUDA)
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
|
|
||||||
} // namespace common
|
} // namespace xgboost::common
|
||||||
} // namespace xgboost
|
|
||||||
|
|||||||
@ -6,20 +6,19 @@
|
|||||||
#ifndef XGBOOST_COMMON_COMMON_H_
|
#ifndef XGBOOST_COMMON_COMMON_H_
|
||||||
#define XGBOOST_COMMON_COMMON_H_
|
#define XGBOOST_COMMON_COMMON_H_
|
||||||
|
|
||||||
#include <xgboost/base.h>
|
#include <algorithm> // for max
|
||||||
#include <xgboost/logging.h>
|
#include <array> // for array
|
||||||
#include <xgboost/span.h>
|
#include <cmath> // for ceil
|
||||||
|
#include <cstddef> // for size_t
|
||||||
|
#include <cstdint> // for int32_t, int64_t
|
||||||
|
#include <sstream> // for basic_istream, operator<<, istringstream
|
||||||
|
#include <string> // for string, basic_string, getline, char_traits
|
||||||
|
#include <tuple> // for make_tuple
|
||||||
|
#include <utility> // for forward, index_sequence, make_index_sequence
|
||||||
|
#include <vector> // for vector
|
||||||
|
|
||||||
#include <algorithm>
|
#include "xgboost/base.h" // for XGBOOST_DEVICE
|
||||||
#include <exception>
|
#include "xgboost/logging.h" // for LOG, LOG_FATAL, LogMessageFatal
|
||||||
#include <functional>
|
|
||||||
#include <limits>
|
|
||||||
#include <numeric>
|
|
||||||
#include <sstream>
|
|
||||||
#include <string>
|
|
||||||
#include <type_traits>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#if defined(__CUDACC__)
|
#if defined(__CUDACC__)
|
||||||
#include <thrust/system/cuda/error.h>
|
#include <thrust/system/cuda/error.h>
|
||||||
@ -52,8 +51,7 @@ inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
|
|||||||
#endif // defined(__CUDACC__)
|
#endif // defined(__CUDACC__)
|
||||||
} // namespace dh
|
} // namespace dh
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost::common {
|
||||||
namespace common {
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Split a string by delimiter
|
* \brief Split a string by delimiter
|
||||||
* \param s String to be split.
|
* \param s String to be split.
|
||||||
@ -69,19 +67,13 @@ inline std::vector<std::string> Split(const std::string& s, char delim) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EscapeU8(std::string const &string, std::string *p_buffer);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
XGBOOST_DEVICE T Max(T a, T b) {
|
XGBOOST_DEVICE T Max(T a, T b) {
|
||||||
return a < b ? b : a;
|
return a < b ? b : a;
|
||||||
}
|
}
|
||||||
|
|
||||||
// simple routine to convert any data to string
|
|
||||||
template<typename T>
|
|
||||||
inline std::string ToString(const T& data) {
|
|
||||||
std::ostringstream os;
|
|
||||||
os << data;
|
|
||||||
return os.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T1, typename T2>
|
template <typename T1, typename T2>
|
||||||
XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) {
|
XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) {
|
||||||
return static_cast<T1>(std::ceil(static_cast<double>(a) / b));
|
return static_cast<T1>(std::ceil(static_cast<double>(a) / b));
|
||||||
@ -195,6 +187,5 @@ template <typename Indexable>
|
|||||||
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
|
XGBOOST_DEVICE size_t LastOf(size_t group, Indexable const &indptr) {
|
||||||
return indptr[group + 1] - 1;
|
return indptr[group + 1] - 1;
|
||||||
}
|
}
|
||||||
} // namespace common
|
} // namespace xgboost::common
|
||||||
} // namespace xgboost
|
|
||||||
#endif // XGBOOST_COMMON_COMMON_H_
|
#endif // XGBOOST_COMMON_COMMON_H_
|
||||||
|
|||||||
@ -1,23 +1,29 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright (c) by Contributors 2019-2022
|
* Copyright 2019-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
|
|
||||||
#include <dmlc/endian.h>
|
#include <array> // for array
|
||||||
|
#include <cctype> // for isdigit
|
||||||
|
#include <cmath> // for isinf, isnan
|
||||||
|
#include <cstdio> // for EOF
|
||||||
|
#include <cstdlib> // for size_t, strtof
|
||||||
|
#include <cstring> // for memcpy
|
||||||
|
#include <initializer_list> // for initializer_list
|
||||||
|
#include <iterator> // for distance
|
||||||
|
#include <limits> // for numeric_limits
|
||||||
|
#include <memory> // for allocator
|
||||||
|
#include <sstream> // for operator<<, basic_ostream, operator&, ios, stringstream
|
||||||
|
#include <system_error> // for errc
|
||||||
|
|
||||||
#include <cctype>
|
#include "./math.h" // for CheckNAN
|
||||||
#include <cmath>
|
#include "charconv.h" // for to_chars, NumericLimits, from_chars, to_chars_result
|
||||||
#include <cstddef>
|
#include "common.h" // for EscapeU8
|
||||||
#include <iterator>
|
#include "xgboost/base.h" // for XGBOOST_EXPECT
|
||||||
#include <limits>
|
#include "xgboost/intrusive_ptr.h" // for IntrusivePtr
|
||||||
#include <sstream>
|
#include "xgboost/json_io.h" // for JsonReader, UBJReader, UBJWriter, JsonWriter, ToBigEn...
|
||||||
|
#include "xgboost/logging.h" // for LOG, LOG_FATAL, LogMessageFatal, LogCheck_NE, CHECK
|
||||||
#include "./math.h"
|
#include "xgboost/string_view.h" // for StringView, operator<<
|
||||||
#include "charconv.h"
|
|
||||||
#include "xgboost/base.h"
|
|
||||||
#include "xgboost/json_io.h"
|
|
||||||
#include "xgboost/logging.h"
|
|
||||||
#include "xgboost/string_view.h"
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
@ -57,12 +63,12 @@ void JsonWriter::Visit(JsonObject const* obj) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/*!
 * \brief Append the textual form of a JSON number to the output stream.
 *
 * The value is formatted into a fixed-size local buffer with to_chars and the produced
 * characters are then copied to the end of the stream.
 *
 * \param num Number node to write.
 */
void JsonWriter::Visit(JsonNumber const* num) {
  std::array<char, NumericLimits<float>::kToCharsSize> number;
  auto const res = to_chars(number.data(), number.data() + number.size(), num->GetNumber());
  // Take the pointer difference first. Adding the stream size to the end pointer before
  // subtracting would form a pointer far outside of the local buffer.
  auto const n_chars = static_cast<std::size_t>(res.ptr - number.data());
  auto const ori_size = stream_->size();
  stream_->resize(ori_size + n_chars);
  std::memcpy(stream_->data() + ori_size, number.data(), n_chars);
}
|
||||||
|
|
||||||
void JsonWriter::Visit(JsonInteger const* num) {
|
void JsonWriter::Visit(JsonInteger const* num) {
|
||||||
@ -91,35 +97,7 @@ void JsonWriter::Visit(JsonString const* str) {
|
|||||||
std::string buffer;
|
std::string buffer;
|
||||||
buffer += '"';
|
buffer += '"';
|
||||||
auto const& string = str->GetString();
|
auto const& string = str->GetString();
|
||||||
for (size_t i = 0; i < string.length(); i++) {
|
common::EscapeU8(string, &buffer);
|
||||||
const char ch = string[i];
|
|
||||||
if (ch == '\\') {
|
|
||||||
if (i < string.size() && string[i+1] == 'u') {
|
|
||||||
buffer += "\\";
|
|
||||||
} else {
|
|
||||||
buffer += "\\\\";
|
|
||||||
}
|
|
||||||
} else if (ch == '"') {
|
|
||||||
buffer += "\\\"";
|
|
||||||
} else if (ch == '\b') {
|
|
||||||
buffer += "\\b";
|
|
||||||
} else if (ch == '\f') {
|
|
||||||
buffer += "\\f";
|
|
||||||
} else if (ch == '\n') {
|
|
||||||
buffer += "\\n";
|
|
||||||
} else if (ch == '\r') {
|
|
||||||
buffer += "\\r";
|
|
||||||
} else if (ch == '\t') {
|
|
||||||
buffer += "\\t";
|
|
||||||
} else if (static_cast<uint8_t>(ch) <= 0x1f) {
|
|
||||||
// Unit separator
|
|
||||||
char buf[8];
|
|
||||||
snprintf(buf, sizeof buf, "\\u%04x", ch);
|
|
||||||
buffer += buf;
|
|
||||||
} else {
|
|
||||||
buffer += ch;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
buffer += '"';
|
buffer += '"';
|
||||||
|
|
||||||
auto s = stream_->size();
|
auto s = stream_->size();
|
||||||
|
|||||||
@ -10,6 +10,7 @@
|
|||||||
#include <cstddef> // for size_t
|
#include <cstddef> // for size_t
|
||||||
#include <cstdint> // for int32_t
|
#include <cstdint> // for int32_t
|
||||||
#include <iterator> // for iterator_traits
|
#include <iterator> // for iterator_traits
|
||||||
|
#include <numeric> // for accumulate
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "common.h" // AssertGPUSupport
|
#include "common.h" // AssertGPUSupport
|
||||||
|
|||||||
@ -797,7 +797,7 @@ class LearnerConfiguration : public Learner {
|
|||||||
bool has_nc {cfg_.find("num_class") != cfg_.cend()};
|
bool has_nc {cfg_.find("num_class") != cfg_.cend()};
|
||||||
// Inject num_class into configuration.
|
// Inject num_class into configuration.
|
||||||
// FIXME(jiamingy): Remove the duplicated parameter in softmax
|
// FIXME(jiamingy): Remove the duplicated parameter in softmax
|
||||||
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
cfg_["num_class"] = std::to_string(mparam_.num_class);
|
||||||
auto& args = *p_args;
|
auto& args = *p_args;
|
||||||
args = {cfg_.cbegin(), cfg_.cend()}; // renew
|
args = {cfg_.cbegin(), cfg_.cend()}; // renew
|
||||||
obj_->Configure(args);
|
obj_->Configure(args);
|
||||||
@ -1076,7 +1076,7 @@ class LearnerIO : public LearnerConfiguration {
|
|||||||
mparam_.major_version = std::get<0>(Version::Self());
|
mparam_.major_version = std::get<0>(Version::Self());
|
||||||
mparam_.minor_version = std::get<1>(Version::Self());
|
mparam_.minor_version = std::get<1>(Version::Self());
|
||||||
|
|
||||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
cfg_["num_feature"] = std::to_string(mparam_.num_feature);
|
||||||
|
|
||||||
auto n = tparam_.__DICT__();
|
auto n = tparam_.__DICT__();
|
||||||
cfg_.insert(n.cbegin(), n.cend());
|
cfg_.insert(n.cbegin(), n.cend());
|
||||||
|
|||||||
@ -398,11 +398,14 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
static std::string const kIndicatorTemplate =
|
static std::string const kIndicatorTemplate =
|
||||||
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
|
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no})ID";
|
||||||
auto split_index = tree[nid].SplitIndex();
|
auto split_index = tree[nid].SplitIndex();
|
||||||
|
auto fname = fmap_.Name(split_index);
|
||||||
|
std::string qfname; // quoted
|
||||||
|
common::EscapeU8(fname, &qfname);
|
||||||
auto result = SuperT::Match(
|
auto result = SuperT::Match(
|
||||||
kIndicatorTemplate,
|
kIndicatorTemplate,
|
||||||
{{"{nid}", std::to_string(nid)},
|
{{"{nid}", std::to_string(nid)},
|
||||||
{"{depth}", std::to_string(depth)},
|
{"{depth}", std::to_string(depth)},
|
||||||
{"{fname}", fmap_.Name(split_index)},
|
{"{fname}", qfname},
|
||||||
{"{yes}", std::to_string(nyes)},
|
{"{yes}", std::to_string(nyes)},
|
||||||
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
||||||
return result;
|
return result;
|
||||||
@ -430,12 +433,14 @@ class JsonGenerator : public TreeGenerator {
|
|||||||
std::string const &template_str, std::string cond,
|
std::string const &template_str, std::string cond,
|
||||||
uint32_t depth) const {
|
uint32_t depth) const {
|
||||||
auto split_index = tree[nid].SplitIndex();
|
auto split_index = tree[nid].SplitIndex();
|
||||||
|
auto fname = split_index < fmap_.Size() ? fmap_.Name(split_index) : std::to_string(split_index);
|
||||||
|
std::string qfname; // quoted
|
||||||
|
common::EscapeU8(fname, &qfname);
|
||||||
std::string const result = SuperT::Match(
|
std::string const result = SuperT::Match(
|
||||||
template_str,
|
template_str,
|
||||||
{{"{nid}", std::to_string(nid)},
|
{{"{nid}", std::to_string(nid)},
|
||||||
{"{depth}", std::to_string(depth)},
|
{"{depth}", std::to_string(depth)},
|
||||||
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
|
{"{fname}", qfname},
|
||||||
std::to_string(split_index)},
|
|
||||||
{"{cond}", cond},
|
{"{cond}", cond},
|
||||||
{"{left}", std::to_string(tree[nid].LeftChild())},
|
{"{left}", std::to_string(tree[nid].LeftChild())},
|
||||||
{"{right}", std::to_string(tree[nid].RightChild())},
|
{"{right}", std::to_string(tree[nid].RightChild())},
|
||||||
|
|||||||
@ -439,6 +439,26 @@ class TestModels:
|
|||||||
'objective': 'multi:softmax'}
|
'objective': 'multi:softmax'}
|
||||||
validate_model(parameters)
|
validate_model(parameters)
|
||||||
|
|
||||||
|
def test_special_model_dump_characters(self):
    """Feature names with quotes, tabs and newlines must survive a JSON model dump.

    A small regression model is trained with such feature names, every tree is dumped
    as JSON, parsed with ``json.loads`` and each ``split`` value found in the parsed
    tree is checked against the original feature names.
    """
    params = {"objective": "reg:squarederror", "max_depth": 3}
    feature_names = ['"feature 0"', "\tfeature\n1", "feature 2"]
    X, y, _ = tm.make_regression(n_samples=128, n_features=3, use_cupy=False)
    Xy = xgb.DMatrix(X, label=y, feature_names=feature_names)
    booster = xgb.train(params, Xy, num_boost_round=3)
    json_dump = booster.get_dump(dump_format="json")
    assert len(json_dump) == 3

    def validate(obj) -> None:
        # Walk dicts and lists alike. Recursing only into dict values would stop at
        # any child nodes that are stored in a list (NOTE(review): the JSON dump is
        # expected to nest child nodes in a "children" array - confirm).
        if isinstance(obj, list):
            for item in obj:
                validate(item)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                if k == "split":
                    assert v in feature_names
                else:
                    validate(v)

    for j_tree in json_dump:
        loaded = json.loads(j_tree)
        validate(loaded)
|
||||||
|
|
||||||
def test_categorical_model_io(self):
|
def test_categorical_model_io(self):
|
||||||
X, y = tm.make_categorical(256, 16, 71, False)
|
X, y = tm.make_categorical(256, 16, 71, False)
|
||||||
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user