Handle formatted JSON input. (#7953)
This commit is contained in:
parent
d3429f2ff6
commit
13b15e07e8
@ -920,7 +920,6 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
|||||||
auto str = common::LoadSequentialFile(fname);
|
auto str = common::LoadSequentialFile(fname);
|
||||||
CHECK_GE(str.size(), 3); // "{}\0"
|
CHECK_GE(str.size(), 3); // "{}\0"
|
||||||
CHECK_EQ(str[0], '{');
|
CHECK_EQ(str[0], '{');
|
||||||
CHECK_EQ(str[str.size() - 2], '}');
|
|
||||||
return str;
|
return str;
|
||||||
};
|
};
|
||||||
if (common::FileExtension(fname) == "json") {
|
if (common::FileExtension(fname) == "json") {
|
||||||
|
|||||||
@ -852,12 +852,23 @@ class LearnerIO : public LearnerConfiguration {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME(jiamingy): Move this out of learner after the old binary model is remove.
|
||||||
|
auto first_non_space = [&](std::string::const_iterator beg, std::string::const_iterator end) {
|
||||||
|
for (auto i = beg; i != end; ++i) {
|
||||||
|
if (!std::isspace(*i)) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return end;
|
||||||
|
};
|
||||||
|
|
||||||
if (header[0] == '{') { // Dispatch to JSON
|
if (header[0] == '{') { // Dispatch to JSON
|
||||||
auto buffer = common::ReadAll(fi, &fp);
|
auto buffer = common::ReadAll(fi, &fp);
|
||||||
Json model;
|
Json model;
|
||||||
if (header[1] == '"') {
|
auto it = first_non_space(buffer.cbegin() + 1, buffer.cend());
|
||||||
|
if (it != buffer.cend() && *it == '"') {
|
||||||
model = Json::Load(StringView{buffer});
|
model = Json::Load(StringView{buffer});
|
||||||
} else if (std::isalpha(header[1])) {
|
} else if (it != buffer.cend() && std::isalpha(*it)) {
|
||||||
model = Json::Load(StringView{buffer}, std::ios::binary);
|
model = Json::Load(StringView{buffer}, std::ios::binary);
|
||||||
} else {
|
} else {
|
||||||
LOG(FATAL) << "Invalid model format";
|
LOG(FATAL) << "Invalid model format";
|
||||||
|
|||||||
@ -308,6 +308,15 @@ class TestModels:
|
|||||||
|
|
||||||
assert old_from_json == old_from_ubj
|
assert old_from_json == old_from_ubj
|
||||||
|
|
||||||
|
raw_json = bst.save_raw(raw_format="json")
|
||||||
|
pretty = json.dumps(json.loads(raw_json), indent=2) + "\n\n"
|
||||||
|
bst.load_model(bytearray(pretty, encoding="ascii"))
|
||||||
|
|
||||||
|
old_from_json = from_jraw.save_raw(raw_format="deprecated")
|
||||||
|
old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
|
||||||
|
|
||||||
|
assert old_from_json == old_from_ubj
|
||||||
|
|
||||||
@pytest.mark.parametrize("ext", ["json", "ubj"])
|
@pytest.mark.parametrize("ext", ["json", "ubj"])
|
||||||
def test_model_json_io(self, ext: str) -> None:
|
def test_model_json_io(self, ext: str) -> None:
|
||||||
parameters = {"booster": "gbtree", "tree_method": "hist"}
|
parameters = {"booster": "gbtree", "tree_method": "hist"}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user