From 2dcb62ddfb2e10d58358195992df829e6d558c83 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 5 Dec 2019 22:15:34 +0800 Subject: [PATCH] Add IO utilities. (#5091) * Add fixed size stream for reading model stream. * Add file extension. --- src/common/io.cc | 80 ++++++++++++++++++++++++++++++++++++- src/common/io.h | 73 +++++++++++++++++++-------------- tests/cpp/common/test_io.cc | 43 ++++++++++++++++++++ 3 files changed, 165 insertions(+), 31 deletions(-) create mode 100644 tests/cpp/common/test_io.cc diff --git a/src/common/io.cc b/src/common/io.cc index 9d3a6b3ad..1f80676a0 100644 --- a/src/common/io.cc +++ b/src/common/io.cc @@ -1,13 +1,15 @@ /*! - * Copyright (c) by Contributors 2019 + * Copyright (c) by XGBoost Contributors 2019 */ #if defined(__unix__) #include #include #include #endif // defined(__unix__) +#include #include #include +#include #include "xgboost/logging.h" #include "io.h" @@ -15,6 +17,81 @@ namespace xgboost { namespace common { +size_t PeekableInStream::Read(void* dptr, size_t size) { + size_t nbuffer = buffer_.length() - buffer_ptr_; + if (nbuffer == 0) return strm_->Read(dptr, size); + if (nbuffer < size) { + std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, nbuffer); + buffer_ptr_ += nbuffer; + return nbuffer + strm_->Read(reinterpret_cast(dptr) + nbuffer, + size - nbuffer); + } else { + std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size); + buffer_ptr_ += size; + return size; + } +} + +size_t PeekableInStream::PeekRead(void* dptr, size_t size) { + size_t nbuffer = buffer_.length() - buffer_ptr_; + if (nbuffer < size) { + buffer_ = buffer_.substr(buffer_ptr_, buffer_.length()); + buffer_ptr_ = 0; + buffer_.resize(size); + size_t nadd = strm_->Read(dmlc::BeginPtr(buffer_) + nbuffer, size - nbuffer); + buffer_.resize(nbuffer + nadd); + std::memcpy(dptr, dmlc::BeginPtr(buffer_), buffer_.length()); + return buffer_.size(); + } else { + std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size); + return size; + } +} + +FixedSizeStream::FixedSizeStream(PeekableInStream* stream) : PeekableInStream(stream), pointer_{0} { + size_t constexpr kInitialSize = 4096; + size_t size {kInitialSize}, total {0}; + buffer_.clear(); + while (true) { + buffer_.resize(size); + size_t read = stream->PeekRead(&buffer_[0], size); + total = read; + if (read < size) { + break; + } + size *= 2; + } + buffer_.resize(total); +} + +size_t FixedSizeStream::Read(void* dptr, size_t size) { + auto read = this->PeekRead(dptr, size); + pointer_ += read; + return read; +} + +size_t FixedSizeStream::PeekRead(void* dptr, size_t size) { + if (size >= buffer_.size() - pointer_) { + std::copy(buffer_.cbegin() + pointer_, buffer_.cend(), reinterpret_cast(dptr)); + return std::distance(buffer_.cbegin() + pointer_, buffer_.cend()); + } else { + auto const beg = buffer_.cbegin() + pointer_; + auto const end = beg + size; + std::copy(beg, end, reinterpret_cast(dptr)); + return std::distance(beg, end); + } +} + +void FixedSizeStream::Seek(size_t pos) { + pointer_ = pos; + CHECK_LE(pointer_, buffer_.size()); +} + +void FixedSizeStream::Take(std::string* out) { + CHECK(out); + *out = std::move(buffer_); +} + std::string LoadSequentialFile(std::string fname) { auto OpenErr = [&fname]() { std::string msg; @@ -59,6 +136,7 @@ std::string LoadSequentialFile(std::string fname) { buffer.resize(fsize + 1); fread(&buffer[0], 1, fsize, f); + buffer.back() = '\0'; fclose(f); #endif // defined(__unix__) return buffer; diff --git a/src/common/io.h b/src/common/io.h index d6072ddd4..193239fbd 100644 --- a/src/common/io.h +++ b/src/common/io.h @@ -13,6 +13,8 @@ #include #include +#include "common.h" + namespace xgboost { namespace common { using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer; @@ -27,36 +29,8 @@ class PeekableInStream : public dmlc::Stream { explicit PeekableInStream(dmlc::Stream* strm) : strm_(strm), buffer_ptr_(0) {} - size_t Read(void* dptr, size_t size) override { - size_t nbuffer = buffer_.length() - buffer_ptr_; - if (nbuffer == 0) return strm_->Read(dptr, size); - if (nbuffer < size) { - std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, nbuffer); - buffer_ptr_ += nbuffer; - return nbuffer + strm_->Read(reinterpret_cast(dptr) + nbuffer, - size - nbuffer); - } else { - std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size); - buffer_ptr_ += size; - return size; - } - } - - size_t PeekRead(void* dptr, size_t size) { - size_t nbuffer = buffer_.length() - buffer_ptr_; - if (nbuffer < size) { - buffer_ = buffer_.substr(buffer_ptr_, buffer_.length()); - buffer_ptr_ = 0; - buffer_.resize(size); - size_t nadd = strm_->Read(dmlc::BeginPtr(buffer_) + nbuffer, size - nbuffer); - buffer_.resize(nbuffer + nadd); - std::memcpy(dptr, dmlc::BeginPtr(buffer_), buffer_.length()); - return buffer_.length(); - } else { - std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size); - return size; - } - } + size_t Read(void* dptr, size_t size) override; + virtual size_t PeekRead(void* dptr, size_t size); void Write(const void* dptr, size_t size) override { LOG(FATAL) << "Not implemented"; @@ -70,10 +44,49 @@ class PeekableInStream : public dmlc::Stream { /*! \brief internal buffer */ std::string buffer_; }; +/*! + * \brief A simple class used to consume `dmlc::Stream' all at once. + * + * With it one can load the rabit checkpoint into a known size string buffer. + */ +class FixedSizeStream : public PeekableInStream { + public: + explicit FixedSizeStream(PeekableInStream* stream); + ~FixedSizeStream() = default; + + size_t Read(void* dptr, size_t size) override; + size_t PeekRead(void* dptr, size_t size) override; + size_t Size() const { return buffer_.size(); } + size_t Tell() const { return pointer_; } + void Seek(size_t pos); + + void Write(const void* dptr, size_t size) override { + LOG(FATAL) << "Not implemented"; + } + + /*! + * \brief Take the buffer from `FixedSizeStream'. The one in `FixedSizeStream' will be + * cleared out. + */ + void Take(std::string* out); + + private: + size_t pointer_; + std::string buffer_; +}; // Optimized for consecutive file loading in unix like systime. std::string LoadSequentialFile(std::string fname); +inline std::string FileExtension(std::string const& fname) { + auto splited = Split(fname, '.'); + if (splited.size() > 1) { + return splited.back(); + } else { + return ""; + } +} + } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_IO_H_ diff --git a/tests/cpp/common/test_io.cc b/tests/cpp/common/test_io.cc new file mode 100644 index 000000000..957c1748b --- /dev/null +++ b/tests/cpp/common/test_io.cc @@ -0,0 +1,43 @@ +/*! + * Copyright (c) by XGBoost Contributors 2019 + */ +#include +#include "../../../src/common/io.h" + +namespace xgboost { +namespace common { +TEST(IO, FileExtension) { + std::string filename {u8"model.json"}; + auto ext = FileExtension(filename); + ASSERT_EQ(ext, u8"json"); +} + +TEST(IO, FixedSizeStream) { + std::string buffer {"This is the content of stream"}; + { + MemoryFixSizeBuffer stream(static_cast(&buffer[0]), buffer.size()); + PeekableInStream peekable(&stream); + FixedSizeStream fixed(&peekable); + + std::string out_buffer; + fixed.Take(&out_buffer); + ASSERT_EQ(buffer, out_buffer); + } + + { + std::string huge_buffer; + for (size_t i = 0; i < 512; i++) { + huge_buffer += buffer; + } + + MemoryFixSizeBuffer stream(static_cast(&huge_buffer[0]), huge_buffer.size()); + PeekableInStream peekable(&stream); + FixedSizeStream fixed(&peekable); + + std::string out_buffer; + fixed.Take(&out_buffer); + ASSERT_EQ(huge_buffer, out_buffer); + } +} +} // namespace common +} // namespace xgboost