Handle formatted JSON input. (#7953)
This commit is contained in:
parent
d3429f2ff6
commit
13b15e07e8
@ -920,7 +920,6 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||
auto str = common::LoadSequentialFile(fname);
|
||||
CHECK_GE(str.size(), 3); // "{}\0"
|
||||
CHECK_EQ(str[0], '{');
|
||||
CHECK_EQ(str[str.size() - 2], '}');
|
||||
return str;
|
||||
};
|
||||
if (common::FileExtension(fname) == "json") {
|
||||
|
||||
@ -852,12 +852,23 @@ class LearnerIO : public LearnerConfiguration {
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME(jiamingy): Move this out of learner after the old binary model is remove.
|
||||
auto first_non_space = [&](std::string::const_iterator beg, std::string::const_iterator end) {
|
||||
for (auto i = beg; i != end; ++i) {
|
||||
if (!std::isspace(*i)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return end;
|
||||
};
|
||||
|
||||
if (header[0] == '{') { // Dispatch to JSON
|
||||
auto buffer = common::ReadAll(fi, &fp);
|
||||
Json model;
|
||||
if (header[1] == '"') {
|
||||
auto it = first_non_space(buffer.cbegin() + 1, buffer.cend());
|
||||
if (it != buffer.cend() && *it == '"') {
|
||||
model = Json::Load(StringView{buffer});
|
||||
} else if (std::isalpha(header[1])) {
|
||||
} else if (it != buffer.cend() && std::isalpha(*it)) {
|
||||
model = Json::Load(StringView{buffer}, std::ios::binary);
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid model format";
|
||||
|
||||
@ -308,6 +308,15 @@ class TestModels:
|
||||
|
||||
assert old_from_json == old_from_ubj
|
||||
|
||||
raw_json = bst.save_raw(raw_format="json")
|
||||
pretty = json.dumps(json.loads(raw_json), indent=2) + "\n\n"
|
||||
bst.load_model(bytearray(pretty, encoding="ascii"))
|
||||
|
||||
old_from_json = from_jraw.save_raw(raw_format="deprecated")
|
||||
old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
|
||||
|
||||
assert old_from_json == old_from_ubj
|
||||
|
||||
@pytest.mark.parametrize("ext", ["json", "ubj"])
|
||||
def test_model_json_io(self, ext: str) -> None:
|
||||
parameters = {"booster": "gbtree", "tree_method": "hist"}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user