Drop support for loading remote files. (#9504)

This commit is contained in:
Jiaming Yuan
2023-08-21 23:34:05 +08:00
committed by GitHub
parent d779a11af9
commit 044fea1281
12 changed files with 43 additions and 112 deletions

View File

@@ -1220,12 +1220,12 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
return str;
};
if (common::FileExtension(fname) == "json") {
auto str = read_file();
Json in{Json::Load(StringView{str})};
auto buffer = read_file();
Json in{Json::Load(StringView{buffer.data(), buffer.size()})};
static_cast<Learner*>(handle)->LoadModel(in);
} else if (common::FileExtension(fname) == "ubj") {
auto str = read_file();
Json in = Json::Load(StringView{str}, std::ios::binary);
auto buffer = read_file();
Json in = Json::Load(StringView{buffer.data(), buffer.size()}, std::ios::binary);
static_cast<Learner *>(handle)->LoadModel(in);
} else {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));

View File

@@ -345,10 +345,10 @@ class CLI {
void LoadModel(std::string const& path, Learner* learner) const {
if (common::FileExtension(path) == "json") {
auto str = common::LoadSequentialFile(path);
CHECK_GT(str.size(), 2);
CHECK_EQ(str[0], '{');
Json in{Json::Load({str.c_str(), str.size()})};
auto buffer = common::LoadSequentialFile(path);
CHECK_GT(buffer.size(), 2);
CHECK_EQ(buffer[0], '{');
Json in{Json::Load({buffer.data(), buffer.size()})};
learner->LoadModel(in);
} else {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(path.c_str(), "r"));

View File

@@ -139,7 +139,7 @@ auto SystemErrorMsg() {
}
} // anonymous namespace
std::string LoadSequentialFile(std::string uri, bool stream) {
std::vector<char> LoadSequentialFile(std::string uri) {
auto OpenErr = [&uri]() {
std::string msg;
msg = "Opening " + uri + " failed: ";
@@ -148,44 +148,20 @@ std::string LoadSequentialFile(std::string uri, bool stream) {
};
auto parsed = dmlc::io::URI(uri.c_str());
CHECK((parsed.protocol == "file://" || parsed.protocol.length() == 0))
<< "Only local file is supported.";
// Read from file.
if ((parsed.protocol == "file://" || parsed.protocol.length() == 0) && !stream) {
std::string buffer;
// Open in binary mode so that correct file size can be computed with
// seekg(). This accommodates Windows platform:
// https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
auto path = std::filesystem::weakly_canonical(std::filesystem::u8path(uri));
std::ifstream ifs(path, std::ios_base::binary | std::ios_base::in);
if (!ifs) {
// https://stackoverflow.com/a/17338934
OpenErr();
}
ifs.seekg(0, std::ios_base::end);
const size_t file_size = static_cast<size_t>(ifs.tellg());
ifs.seekg(0, std::ios_base::beg);
buffer.resize(file_size + 1);
ifs.read(&buffer[0], file_size);
buffer.back() = '\0';
return buffer;
auto path = std::filesystem::weakly_canonical(std::filesystem::u8path(uri));
std::ifstream ifs(path, std::ios_base::binary | std::ios_base::in);
if (!ifs) {
// https://stackoverflow.com/a/17338934
OpenErr();
}
// Read from remote.
std::unique_ptr<dmlc::Stream> fs{dmlc::Stream::Create(uri.c_str(), "r")};
std::string buffer;
size_t constexpr kInitialSize = 4096;
size_t size {kInitialSize}, total {0};
while (true) {
buffer.resize(total + size);
size_t read = fs->Read(&buffer[total], size);
total += read;
if (read < size) {
break;
}
size *= 2;
}
buffer.resize(total);
auto file_size = std::filesystem::file_size(path);
std::vector<char> buffer(file_size);
ifs.read(&buffer[0], file_size);
return buffer;
}

View File

@@ -84,16 +84,14 @@ class FixedSizeStream : public PeekableInStream {
std::string buffer_;
};
/*!
* \brief Helper function for loading consecutive file to avoid dmlc Stream when possible.
/**
* @brief Helper function for reading the entire content of a local file into memory.
*
* \param uri URI or file name to file.
* \param stream Use dmlc Stream unconditionally if set to true. Used for running test
* without remote filesystem.
* @param uri Path to a local file, or a `file://` URI; any other protocol is rejected.
*
* \return File content.
* @return File content.
*/
std::string LoadSequentialFile(std::string uri, bool stream = false);
std::vector<char> LoadSequentialFile(std::string uri);
/**
* \brief Get file extension from file name.