From 6dab74689cd5ae6bc94020b29a88691819eeebdb Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sun, 13 Oct 2019 00:09:25 -0400 Subject: [PATCH] Add `SeekEnd` to `MemoryFixSizeBuffer`. (#109) * Don't assert buffer size. --- .gitignore | 2 +- include/rabit/internal/io.h | 15 +++++++++++---- test/.gitignore | 1 - test/cpp/CMakeLists.txt | 5 ++++- test/cpp/test_io.cc | 18 ++++++++++++++++++ 5 files changed, 34 insertions(+), 7 deletions(-) create mode 100644 test/cpp/test_io.cc diff --git a/.gitignore b/.gitignore index eedb5b97c..ad9fedf10 100644 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,4 @@ cmake-build-debug/ # cmake build/ - +compile_commands.json \ No newline at end of file diff --git a/include/rabit/internal/io.h b/include/rabit/internal/io.h index 59494fd19..a492cf705 100644 --- a/include/rabit/internal/io.h +++ b/include/rabit/internal/io.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2014 by Contributors + * Copyright (c) 2014-2019 by Contributors * \file io.h * \brief utilities with different serializable implementations * \author Tianqi Chen @@ -11,6 +11,7 @@ #include #include #include +#include #include "rabit/internal/utils.h" #include "rabit/serializable.h" @@ -20,6 +21,10 @@ namespace utils { typedef dmlc::SeekStream SeekStream; /*! \brief fixed size memory buffer */ struct MemoryFixSizeBuffer : public SeekStream { + public: + // similar to SEEK_END in libc + static size_t constexpr SeekEnd = std::numeric_limits::max(); + public: MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) : p_buffer_(reinterpret_cast(p_buffer)), @@ -28,8 +33,6 @@ struct MemoryFixSizeBuffer : public SeekStream { } virtual ~MemoryFixSizeBuffer(void) {} virtual size_t Read(void *ptr, size_t size) { - utils::Assert(curr_ptr_ + size <= buffer_size_, - "read can not have position excceed buffer length"); size_t nread = std::min(buffer_size_ - curr_ptr_, size); if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread); curr_ptr_ += nread; @@ -43,7 +46,11 @@ struct MemoryFixSizeBuffer : public SeekStream { curr_ptr_ += size; } virtual void Seek(size_t pos) { - curr_ptr_ = static_cast(pos); + if (pos == SeekEnd) { + curr_ptr_ = buffer_size_; + } else { + curr_ptr_ = static_cast(pos); + } } virtual size_t Tell(void) { return curr_ptr_; diff --git a/test/.gitignore b/test/.gitignore index eb87d8f26..8e1ff376e 100644 --- a/test/.gitignore +++ b/test/.gitignore @@ -1,4 +1,3 @@ *.mpi -test_* *_test *_recover diff --git a/test/cpp/CMakeLists.txt b/test/cpp/CMakeLists.txt index e7c15bc3e..979c059a8 100644 --- a/test/cpp/CMakeLists.txt +++ b/test/cpp/CMakeLists.txt @@ -2,7 +2,10 @@ find_package(GTest REQUIRED) add_executable( unit_tests - allreduce_base_test.cc allreduce_robust_test.cc allreduce_mock_test.cc + test_io.cc + allreduce_robust_test.cc + allreduce_base_test.cc + allreduce_mock_test.cc test_main.cpp) target_link_libraries( diff --git a/test/cpp/test_io.cc b/test/cpp/test_io.cc new file mode 100644 index 000000000..0e4b70b1b --- /dev/null +++ b/test/cpp/test_io.cc @@ -0,0 +1,18 @@ +/*! + * Copyright (c) 2019 by Contributors + */ +#include +#include + +#include + +namespace rabit { +TEST(MemoryFixSizeBuffer, Seek) { + size_t constexpr kSize { 64 }; + std::vector memory( kSize ); + utils::MemoryFixSizeBuffer buf(memory.data(), memory.size()); + buf.Seek(utils::MemoryFixSizeBuffer::SeekEnd); + size_t end = buf.Tell(); + ASSERT_EQ(end, kSize); +} +} // namespace rabit