diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index d72eb077b..94c9109bc 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -920,7 +920,6 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) { auto str = common::LoadSequentialFile(fname); CHECK_GE(str.size(), 3); // "{}\0" CHECK_EQ(str[0], '{'); - CHECK_EQ(str[str.size() - 2], '}'); return str; }; if (common::FileExtension(fname) == "json") { diff --git a/src/learner.cc b/src/learner.cc index 5d7d067e7..c27a0f514 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -852,12 +852,23 @@ class LearnerIO : public LearnerConfiguration { } } + // FIXME(jiamingy): Move this out of learner after the old binary model is remove. + auto first_non_space = [&](std::string::const_iterator beg, std::string::const_iterator end) { + for (auto i = beg; i != end; ++i) { + if (!std::isspace(*i)) { + return i; + } + } + return end; + }; + if (header[0] == '{') { // Dispatch to JSON auto buffer = common::ReadAll(fi, &fp); Json model; - if (header[1] == '"') { + auto it = first_non_space(buffer.cbegin() + 1, buffer.cend()); + if (it != buffer.cend() && *it == '"') { model = Json::Load(StringView{buffer}); - } else if (std::isalpha(header[1])) { + } else if (it != buffer.cend() && std::isalpha(*it)) { model = Json::Load(StringView{buffer}, std::ios::binary); } else { LOG(FATAL) << "Invalid model format"; diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 510aec506..82d0096cf 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -308,6 +308,15 @@ class TestModels: assert old_from_json == old_from_ubj + raw_json = bst.save_raw(raw_format="json") + pretty = json.dumps(json.loads(raw_json), indent=2) + "\n\n" + bst.load_model(bytearray(pretty, encoding="ascii")) + + old_from_json = from_jraw.save_raw(raw_format="deprecated") + old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated") + + assert old_from_json == old_from_ubj + @pytest.mark.parametrize("ext", ["json", "ubj"]) def test_model_json_io(self, ext: str) -> None: parameters = {"booster": "gbtree", "tree_method": "hist"}