From 4b0852ee41d0a83aed08f24cf3344da61835f6cd Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 7 Jul 2020 03:07:12 +0800
Subject: [PATCH] Use dmlc stream when URI protocol is not local file. (#5857)

---
Summary of the change, as shown by the hunks below:
  * LoadSequentialFile(uri, stream = false) reads through std::ifstream when
    the parsed protocol is "file://" or empty and `stream` is false; the
    returned buffer is file_size + 1 bytes long and ends with '\0'.
  * Every other case goes through dmlc::Stream::Create(uri, "r"), reading
    chunks that start at 4096 bytes and double after each full read; the
    result is trimmed to the bytes actually read (no trailing '\0').

Review notes:
  * NOTE(review): the local branch opens `uri` verbatim, so a "file://..."
    URI looks like it is handed to std::ifstream with the scheme still
    attached -- confirm whether parsed.name should be used instead.
  * NOTE(review): tellg() is cast to size_t without checking for failure.
  * NOTE(review): the two branches disagree on the trailing '\0'; confirm
    that callers do not depend on it.
  * NOTE(review): `fout` and `content` in the new test are never used.

 src/common/io.cc            | 71 +++++++++++++++++++++++--------------
 src/common/io.h             | 11 +++++-
 tests/cpp/common/test_io.cc | 41 +++++++++++++++++++++
 3 files changed, 96 insertions(+), 27 deletions(-)

diff --git a/src/common/io.cc b/src/common/io.cc
index cbf11ae0d..01bc1e36b 100644
--- a/src/common/io.cc
+++ b/src/common/io.cc
@@ -9,6 +9,7 @@
 #include <algorithm>
 #include <fstream>
 #include <string>
+#include <memory>
 #include <utility>
 #include <cstdio>
 
@@ -93,35 +94,53 @@ void FixedSizeStream::Take(std::string* out) {
   *out = std::move(buffer_);
 }
 
-std::string LoadSequentialFile(std::string fname) {
-  auto OpenErr = [&fname]() {
-                   std::string msg;
-                   msg = "Opening " + fname + " failed: ";
-                   msg += strerror(errno);
-                   LOG(FATAL) << msg;
-                 };
-  auto ReadErr = [&fname]() {
-                   std::string msg {"Error in reading file: "};
-                   msg += fname;
-                   msg += ": ";
-                   msg += strerror(errno);
-                   LOG(FATAL) << msg;
-                 };
+std::string LoadSequentialFile(std::string uri, bool stream) {
+  auto OpenErr = [&uri]() {
+    std::string msg;
+    msg = "Opening " + uri + " failed: ";
+    msg += strerror(errno);
+    LOG(FATAL) << msg;
+  };
+  auto parsed = dmlc::io::URI(uri.c_str());
+  // Read from file.
+  if ((parsed.protocol == "file://" || parsed.protocol.length() == 0) && !stream) {
+    std::string buffer;
+    // Open in binary mode so that correct file size can be computed with
+    // seekg(). This accommodates Windows platform:
+    // https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
+    std::ifstream ifs(uri, std::ios_base::binary | std::ios_base::in);
+    if (!ifs) {
+      // https://stackoverflow.com/a/17338934
+      OpenErr();
+    }
+
+    ifs.seekg(0, std::ios_base::end);
+    const size_t file_size = static_cast<size_t>(ifs.tellg());
+    ifs.seekg(0, std::ios_base::beg);
+    buffer.resize(file_size + 1);
+    ifs.read(&buffer[0], file_size);
+    buffer.back() = '\0';
+
+    return buffer;
+  }
+
+  // Read from remote.
+  std::unique_ptr<dmlc::Stream> fs{dmlc::Stream::Create(uri.c_str(), "r")};
   std::string buffer;
-  // Open in binary mode so that correct file size can be computed with seekg().
-  // This accommodates Windows platform:
-  // https://docs.microsoft.com/en-us/cpp/standard-library/basic-istream-class?view=vs-2019#seekg
-  std::ifstream ifs(fname, std::ios_base::binary | std::ios_base::in);
-  ifs.seekg(0, std::ios_base::end);
-  const size_t file_size = static_cast<size_t>(ifs.tellg());
-  ifs.seekg(0, std::ios_base::beg);
-  buffer.resize(file_size + 1);
-  ifs.read(&buffer[0], file_size);
-  buffer.back() = '\0';
-
+  size_t constexpr kInitialSize = 4096;
+  size_t size {kInitialSize}, total {0};
+  while (true) {
+    buffer.resize(total + size);
+    size_t read = fs->Read(&buffer[total], size);
+    total += read;
+    if (read < size) {
+      break;
+    }
+    size *= 2;
+  }
+  buffer.resize(total);
   return buffer;
 }
-
 
 }  // namespace common
 }  // namespace xgboost
diff --git a/src/common/io.h b/src/common/io.h
index 4ae3fcb02..d9544a6d1 100644
--- a/src/common/io.h
+++ b/src/common/io.h
@@ -75,7 +75,16 @@ class FixedSizeStream : public PeekableInStream {
   std::string buffer_;
 };
 
-std::string LoadSequentialFile(std::string fname);
+/*!
+ * \brief Helper function for loading consecutive file to avoid dmlc Stream when possible.
+ *
+ * \param uri URI or file name to file.
+ * \param stream Use dmlc Stream unconditionally if set to true. Used for running test
+ *               without remote filesystem.
+ *
+ * \return File content.
+ */
+std::string LoadSequentialFile(std::string uri, bool stream = false);
 
 inline std::string FileExtension(std::string const& fname) {
   auto splited = Split(fname, '.');
diff --git a/tests/cpp/common/test_io.cc b/tests/cpp/common/test_io.cc
index 957c1748b..192cdc670 100644
--- a/tests/cpp/common/test_io.cc
+++ b/tests/cpp/common/test_io.cc
@@ -2,6 +2,11 @@
  * Copyright (c) by XGBoost Contributors 2019
  */
 #include <gtest/gtest.h>
+#include <dmlc/filesystem.h>
+
+#include <fstream>
+
+#include "../helpers.h"
 #include "../../../src/common/io.h"
 
 namespace xgboost {
@@ -39,5 +44,41 @@ TEST(IO, FixedSizeStream) {
     ASSERT_EQ(huge_buffer, out_buffer);
   }
 }
+
+TEST(IO, LoadSequentialFile) {
+  EXPECT_THROW(LoadSequentialFile("non-exist"), dmlc::Error);
+
+  dmlc::TemporaryDirectory tempdir;
+  std::ofstream fout(tempdir.path + "test_file");
+  std::string content;
+
+  // Generate a JSON file.
+  size_t constexpr kRows = 1000, kCols = 100;
+  std::shared_ptr<DMatrix> p_dmat{
+    RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
+  std::unique_ptr<Learner> learner { Learner::Create({p_dmat}) };
+  learner->SetParam("tree_method", "hist");
+  learner->Configure();
+
+  for (int32_t iter = 0; iter < 10; ++iter) {
+    learner->UpdateOneIter(iter, p_dmat);
+  }
+  Json out { Object() };
+  learner->SaveModel(&out);
+  std::string str;
+  Json::Dump(out, &str);
+
+  std::string tmpfile = tempdir.path + "/model.json";
+  {
+    std::unique_ptr<dmlc::Stream> fo(
+        dmlc::Stream::Create(tmpfile.c_str(), "w"));
+    fo->Write(str.c_str(), str.size());
+  }
+
+  auto loaded = LoadSequentialFile(tmpfile, true);
+  ASSERT_EQ(loaded, str);
+
+  ASSERT_THROW(LoadSequentialFile("non-exist", true), dmlc::Error);
+}
 }  // namespace common
 }  // namespace xgboost