Handle formatted JSON input. (#7953)

This commit is contained in:
Jiaming Yuan 2022-06-01 16:20:58 +08:00 committed by GitHub
parent d3429f2ff6
commit 13b15e07e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 3 deletions

View File

@ -920,7 +920,6 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
auto str = common::LoadSequentialFile(fname);
CHECK_GE(str.size(), 3); // "{}\0"
CHECK_EQ(str[0], '{');
CHECK_EQ(str[str.size() - 2], '}');
return str;
};
if (common::FileExtension(fname) == "json") {

View File

@ -852,12 +852,23 @@ class LearnerIO : public LearnerConfiguration {
}
}
// FIXME(jiamingy): Move this out of learner after the old binary model is removed.
// Return an iterator to the first character in [beg, end) that is not
// whitespace according to std::isspace, or `end` when the range is empty or
// contains only whitespace.
auto first_non_space = [](std::string::const_iterator beg, std::string::const_iterator end) {
  for (auto it = beg; it != end; ++it) {
    // The <cctype> classifiers require a value representable as unsigned char
    // (or EOF); passing a negative plain `char` (any byte >= 0x80 where char is
    // signed) is undefined behaviour, so convert before classifying.
    if (!std::isspace(static_cast<unsigned char>(*it))) {
      return it;
    }
  }
  return end;
};
if (header[0] == '{') { // Dispatch to JSON
auto buffer = common::ReadAll(fi, &fp);
Json model;
if (header[1] == '"') {
auto it = first_non_space(buffer.cbegin() + 1, buffer.cend());
if (it != buffer.cend() && *it == '"') {
model = Json::Load(StringView{buffer});
} else if (std::isalpha(header[1])) {
} else if (it != buffer.cend() && std::isalpha(*it)) {
model = Json::Load(StringView{buffer}, std::ios::binary);
} else {
LOG(FATAL) << "Invalid model format";

View File

@ -308,6 +308,15 @@ class TestModels:
assert old_from_json == old_from_ubj
raw_json = bst.save_raw(raw_format="json")
pretty = json.dumps(json.loads(raw_json), indent=2) + "\n\n"
bst.load_model(bytearray(pretty, encoding="ascii"))
old_from_json = from_jraw.save_raw(raw_format="deprecated")
old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
assert old_from_json == old_from_ubj
@pytest.mark.parametrize("ext", ["json", "ubj"])
def test_model_json_io(self, ext: str) -> None:
parameters = {"booster": "gbtree", "tree_method": "hist"}