Use dmlc stream when URI protocol is not local file. (#5857)
This commit is contained in:
parent
0f17e35bce
commit
4b0852ee41
@ -9,6 +9,7 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
|
|
||||||
@ -93,26 +94,27 @@ void FixedSizeStream::Take(std::string* out) {
|
|||||||
*out = std::move(buffer_);
|
*out = std::move(buffer_);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string LoadSequentialFile(std::string fname) {
|
std::string LoadSequentialFile(std::string uri, bool stream) {
|
||||||
auto OpenErr = [&fname]() {
|
auto OpenErr = [&uri]() {
|
||||||
std::string msg;
|
std::string msg;
|
||||||
msg = "Opening " + fname + " failed: ";
|
msg = "Opening " + uri + " failed: ";
|
||||||
msg += strerror(errno);
|
|
||||||
LOG(FATAL) << msg;
|
|
||||||
};
|
|
||||||
auto ReadErr = [&fname]() {
|
|
||||||
std::string msg {"Error in reading file: "};
|
|
||||||
msg += fname;
|
|
||||||
msg += ": ";
|
|
||||||
msg += strerror(errno);
|
msg += strerror(errno);
|
||||||
LOG(FATAL) << msg;
|
LOG(FATAL) << msg;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
auto parsed = dmlc::io::URI(uri.c_str());
|
||||||
|
// Read from file.
|
||||||
|
if ((parsed.protocol == "file://" || parsed.protocol.length() == 0) && !stream) {
|
||||||
std::string buffer;
|
std::string buffer;
|
||||||
// Open in binary mode so that correct file size can be computed with seekg().
|
// Open in binary mode so that correct file size can be computed with
|
||||||
// This accommodates Windows platform:
|
// seekg(). This accommodates Windows platform:
|
||||||
// https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
|
// https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
|
||||||
std::ifstream ifs(fname, std::ios_base::binary | std::ios_base::in);
|
std::ifstream ifs(uri, std::ios_base::binary | std::ios_base::in);
|
||||||
|
if (!ifs) {
|
||||||
|
// https://stackoverflow.com/a/17338934
|
||||||
|
OpenErr();
|
||||||
|
}
|
||||||
|
|
||||||
ifs.seekg(0, std::ios_base::end);
|
ifs.seekg(0, std::ios_base::end);
|
||||||
const size_t file_size = static_cast<size_t>(ifs.tellg());
|
const size_t file_size = static_cast<size_t>(ifs.tellg());
|
||||||
ifs.seekg(0, std::ios_base::beg);
|
ifs.seekg(0, std::ios_base::beg);
|
||||||
@ -123,5 +125,22 @@ std::string LoadSequentialFile(std::string fname) {
|
|||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read from remote.
|
||||||
|
std::unique_ptr<dmlc::Stream> fs{dmlc::Stream::Create(uri.c_str(), "r")};
|
||||||
|
std::string buffer;
|
||||||
|
size_t constexpr kInitialSize = 4096;
|
||||||
|
size_t size {kInitialSize}, total {0};
|
||||||
|
while (true) {
|
||||||
|
buffer.resize(total + size);
|
||||||
|
size_t read = fs->Read(&buffer[total], size);
|
||||||
|
total += read;
|
||||||
|
if (read < size) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
size *= 2;
|
||||||
|
}
|
||||||
|
buffer.resize(total);
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -75,7 +75,16 @@ class FixedSizeStream : public PeekableInStream {
|
|||||||
std::string buffer_;
|
std::string buffer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string LoadSequentialFile(std::string fname);
|
/*!
|
||||||
|
* \brief Helper function for loading a whole file sequentially, avoiding dmlc Stream when possible.
|
||||||
|
*
|
||||||
|
* \param uri URI or path of the file to load.
|
||||||
|
* \param stream Use dmlc Stream unconditionally if set to true. Used for running tests
|
||||||
|
* without remote filesystem.
|
||||||
|
*
|
||||||
|
* \return File content.
|
||||||
|
*/
|
||||||
|
std::string LoadSequentialFile(std::string uri, bool stream = false);
|
||||||
|
|
||||||
inline std::string FileExtension(std::string const& fname) {
|
inline std::string FileExtension(std::string const& fname) {
|
||||||
auto splited = Split(fname, '.');
|
auto splited = Split(fname, '.');
|
||||||
|
|||||||
@ -2,6 +2,11 @@
|
|||||||
* Copyright (c) by XGBoost Contributors 2019
|
* Copyright (c) by XGBoost Contributors 2019
|
||||||
*/
|
*/
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include <dmlc/filesystem.h>
|
||||||
|
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
|
#include "../helpers.h"
|
||||||
#include "../../../src/common/io.h"
|
#include "../../../src/common/io.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -39,5 +44,41 @@ TEST(IO, FixedSizeStream) {
|
|||||||
ASSERT_EQ(huge_buffer, out_buffer);
|
ASSERT_EQ(huge_buffer, out_buffer);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(IO, LoadSequentialFile) {
|
||||||
|
EXPECT_THROW(LoadSequentialFile("non-exist"), dmlc::Error);
|
||||||
|
|
||||||
|
dmlc::TemporaryDirectory tempdir;
|
||||||
|
std::ofstream fout(tempdir.path + "test_file");
|
||||||
|
std::string content;
|
||||||
|
|
||||||
|
// Generate a JSON file.
|
||||||
|
size_t constexpr kRows = 1000, kCols = 100;
|
||||||
|
std::shared_ptr<DMatrix> p_dmat{
|
||||||
|
RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
|
||||||
|
std::unique_ptr<Learner> learner { Learner::Create({p_dmat}) };
|
||||||
|
learner->SetParam("tree_method", "hist");
|
||||||
|
learner->Configure();
|
||||||
|
|
||||||
|
for (int32_t iter = 0; iter < 10; ++iter) {
|
||||||
|
learner->UpdateOneIter(iter, p_dmat);
|
||||||
|
}
|
||||||
|
Json out { Object() };
|
||||||
|
learner->SaveModel(&out);
|
||||||
|
std::string str;
|
||||||
|
Json::Dump(out, &str);
|
||||||
|
|
||||||
|
std::string tmpfile = tempdir.path + "/model.json";
|
||||||
|
{
|
||||||
|
std::unique_ptr<dmlc::Stream> fo(
|
||||||
|
dmlc::Stream::Create(tmpfile.c_str(), "w"));
|
||||||
|
fo->Write(str.c_str(), str.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto loaded = LoadSequentialFile(tmpfile, true);
|
||||||
|
ASSERT_EQ(loaded, str);
|
||||||
|
|
||||||
|
ASSERT_THROW(LoadSequentialFile("non-exist", true), dmlc::Error);
|
||||||
|
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user