Add IO utilities. (#5091)
* Add fixed size stream for reading model stream. * Add file extension.
This commit is contained in:
parent
64af1ecf86
commit
2dcb62ddfb
@ -1,13 +1,15 @@
|
||||
/*!
|
||||
* Copyright (c) by Contributors 2019
|
||||
* Copyright (c) by XGBoost Contributors 2019
|
||||
*/
|
||||
#if defined(__unix__)
|
||||
#include <sys/stat.h>
|
||||
#include <fcntl.h>
|
||||
#include <unistd.h>
|
||||
#endif // defined(__unix__)
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
#include "io.h"
|
||||
@ -15,6 +17,81 @@
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
size_t PeekableInStream::Read(void* dptr, size_t size) {
|
||||
size_t nbuffer = buffer_.length() - buffer_ptr_;
|
||||
if (nbuffer == 0) return strm_->Read(dptr, size);
|
||||
if (nbuffer < size) {
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, nbuffer);
|
||||
buffer_ptr_ += nbuffer;
|
||||
return nbuffer + strm_->Read(reinterpret_cast<char*>(dptr) + nbuffer,
|
||||
size - nbuffer);
|
||||
} else {
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
|
||||
buffer_ptr_ += size;
|
||||
return size;
|
||||
}
|
||||
}
|
||||
|
||||
size_t PeekableInStream::PeekRead(void* dptr, size_t size) {
|
||||
size_t nbuffer = buffer_.length() - buffer_ptr_;
|
||||
if (nbuffer < size) {
|
||||
buffer_ = buffer_.substr(buffer_ptr_, buffer_.length());
|
||||
buffer_ptr_ = 0;
|
||||
buffer_.resize(size);
|
||||
size_t nadd = strm_->Read(dmlc::BeginPtr(buffer_) + nbuffer, size - nbuffer);
|
||||
buffer_.resize(nbuffer + nadd);
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_), buffer_.length());
|
||||
return buffer_.size();
|
||||
} else {
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
|
||||
return size;
|
||||
}
|
||||
}
|
||||
|
||||
FixedSizeStream::FixedSizeStream(PeekableInStream* stream) : PeekableInStream(stream), pointer_{0} {
|
||||
size_t constexpr kInitialSize = 4096;
|
||||
size_t size {kInitialSize}, total {0};
|
||||
buffer_.clear();
|
||||
while (true) {
|
||||
buffer_.resize(size);
|
||||
size_t read = stream->PeekRead(&buffer_[0], size);
|
||||
total = read;
|
||||
if (read < size) {
|
||||
break;
|
||||
}
|
||||
size *= 2;
|
||||
}
|
||||
buffer_.resize(total);
|
||||
}
|
||||
|
||||
size_t FixedSizeStream::Read(void* dptr, size_t size) {
|
||||
auto read = this->PeekRead(dptr, size);
|
||||
pointer_ += read;
|
||||
return read;
|
||||
}
|
||||
|
||||
size_t FixedSizeStream::PeekRead(void* dptr, size_t size) {
|
||||
if (size >= buffer_.size() - pointer_) {
|
||||
std::copy(buffer_.cbegin() + pointer_, buffer_.cend(), reinterpret_cast<char*>(dptr));
|
||||
return std::distance(buffer_.cbegin() + pointer_, buffer_.cend());
|
||||
} else {
|
||||
auto const beg = buffer_.cbegin() + pointer_;
|
||||
auto const end = beg + size;
|
||||
std::copy(beg, end, reinterpret_cast<char*>(dptr));
|
||||
return std::distance(beg, end);
|
||||
}
|
||||
}
|
||||
|
||||
void FixedSizeStream::Seek(size_t pos) {
|
||||
pointer_ = pos;
|
||||
CHECK_LE(pointer_, buffer_.size());
|
||||
}
|
||||
|
||||
void FixedSizeStream::Take(std::string* out) {
|
||||
CHECK(out);
|
||||
*out = std::move(buffer_);
|
||||
}
|
||||
|
||||
std::string LoadSequentialFile(std::string fname) {
|
||||
auto OpenErr = [&fname]() {
|
||||
std::string msg;
|
||||
@ -59,6 +136,7 @@ std::string LoadSequentialFile(std::string fname) {
|
||||
|
||||
buffer.resize(fsize + 1);
|
||||
fread(&buffer[0], 1, fsize, f);
|
||||
buffer.back() = '\0';
|
||||
fclose(f);
|
||||
#endif // defined(__unix__)
|
||||
return buffer;
|
||||
|
||||
@ -13,6 +13,8 @@
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
|
||||
#include "common.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
using MemoryFixSizeBuffer = rabit::utils::MemoryFixSizeBuffer;
|
||||
@ -27,36 +29,8 @@ class PeekableInStream : public dmlc::Stream {
|
||||
explicit PeekableInStream(dmlc::Stream* strm)
|
||||
: strm_(strm), buffer_ptr_(0) {}
|
||||
|
||||
size_t Read(void* dptr, size_t size) override {
|
||||
size_t nbuffer = buffer_.length() - buffer_ptr_;
|
||||
if (nbuffer == 0) return strm_->Read(dptr, size);
|
||||
if (nbuffer < size) {
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, nbuffer);
|
||||
buffer_ptr_ += nbuffer;
|
||||
return nbuffer + strm_->Read(reinterpret_cast<char*>(dptr) + nbuffer,
|
||||
size - nbuffer);
|
||||
} else {
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
|
||||
buffer_ptr_ += size;
|
||||
return size;
|
||||
}
|
||||
}
|
||||
|
||||
size_t PeekRead(void* dptr, size_t size) {
|
||||
size_t nbuffer = buffer_.length() - buffer_ptr_;
|
||||
if (nbuffer < size) {
|
||||
buffer_ = buffer_.substr(buffer_ptr_, buffer_.length());
|
||||
buffer_ptr_ = 0;
|
||||
buffer_.resize(size);
|
||||
size_t nadd = strm_->Read(dmlc::BeginPtr(buffer_) + nbuffer, size - nbuffer);
|
||||
buffer_.resize(nbuffer + nadd);
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_), buffer_.length());
|
||||
return buffer_.length();
|
||||
} else {
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
|
||||
return size;
|
||||
}
|
||||
}
|
||||
size_t Read(void* dptr, size_t size) override;
|
||||
virtual size_t PeekRead(void* dptr, size_t size);
|
||||
|
||||
void Write(const void* dptr, size_t size) override {
|
||||
LOG(FATAL) << "Not implemented";
|
||||
@ -70,10 +44,49 @@ class PeekableInStream : public dmlc::Stream {
|
||||
/*! \brief internal buffer */
|
||||
std::string buffer_;
|
||||
};
|
||||
/*!
|
||||
* \brief A simple class used to consume `dmlc::Stream' all at once.
|
||||
*
|
||||
* With it one can load the rabit checkpoint into a known size string buffer.
|
||||
*/
|
||||
class FixedSizeStream : public PeekableInStream {
|
||||
public:
|
||||
explicit FixedSizeStream(PeekableInStream* stream);
|
||||
~FixedSizeStream() = default;
|
||||
|
||||
size_t Read(void* dptr, size_t size) override;
|
||||
size_t PeekRead(void* dptr, size_t size) override;
|
||||
size_t Size() const { return buffer_.size(); }
|
||||
size_t Tell() const { return pointer_; }
|
||||
void Seek(size_t pos);
|
||||
|
||||
void Write(const void* dptr, size_t size) override {
|
||||
LOG(FATAL) << "Not implemented";
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Take the buffer from `FixedSizeStream'. The one in `FixedSizeStream' will be
|
||||
* cleared out.
|
||||
*/
|
||||
void Take(std::string* out);
|
||||
|
||||
private:
|
||||
size_t pointer_;
|
||||
std::string buffer_;
|
||||
};
|
||||
|
||||
// Optimized for consecutive file loading in unix like systime.
|
||||
std::string LoadSequentialFile(std::string fname);
|
||||
|
||||
inline std::string FileExtension(std::string const& fname) {
|
||||
auto splited = Split(fname, '.');
|
||||
if (splited.size() > 1) {
|
||||
return splited.back();
|
||||
} else {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_IO_H_
|
||||
|
||||
43
tests/cpp/common/test_io.cc
Normal file
43
tests/cpp/common/test_io.cc
Normal file
@ -0,0 +1,43 @@
|
||||
/*!
|
||||
* Copyright (c) by XGBoost Contributors 2019
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include "../../../src/common/io.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
TEST(IO, FileExtension) {
|
||||
std::string filename {u8"model.json"};
|
||||
auto ext = FileExtension(filename);
|
||||
ASSERT_EQ(ext, u8"json");
|
||||
}
|
||||
|
||||
TEST(IO, FixedSizeStream) {
|
||||
std::string buffer {"This is the content of stream"};
|
||||
{
|
||||
MemoryFixSizeBuffer stream(static_cast<void *>(&buffer[0]), buffer.size());
|
||||
PeekableInStream peekable(&stream);
|
||||
FixedSizeStream fixed(&peekable);
|
||||
|
||||
std::string out_buffer;
|
||||
fixed.Take(&out_buffer);
|
||||
ASSERT_EQ(buffer, out_buffer);
|
||||
}
|
||||
|
||||
{
|
||||
std::string huge_buffer;
|
||||
for (size_t i = 0; i < 512; i++) {
|
||||
huge_buffer += buffer;
|
||||
}
|
||||
|
||||
MemoryFixSizeBuffer stream(static_cast<void*>(&huge_buffer[0]), huge_buffer.size());
|
||||
PeekableInStream peekable(&stream);
|
||||
FixedSizeStream fixed(&peekable);
|
||||
|
||||
std::string out_buffer;
|
||||
fixed.Take(&out_buffer);
|
||||
ASSERT_EQ(huge_buffer, out_buffer);
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
Loading…
x
Reference in New Issue
Block a user