Add SeekEnd to MemoryFixSizeBuffer. (#109)

* Don't assert buffer size.
This commit is contained in:
Jiaming Yuan 2019-10-13 00:09:25 -04:00 committed by GitHub
parent 5d1b613910
commit 6dab74689c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 7 deletions

2
.gitignore vendored
View File

@ -49,4 +49,4 @@ cmake-build-debug/
# cmake
build/
compile_commands.json

View File

@ -1,5 +1,5 @@
/*!
* Copyright (c) 2014 by Contributors
* Copyright (c) 2014-2019 by Contributors
* \file io.h
* \brief utilities with different serializable implementations
* \author Tianqi Chen
@ -11,6 +11,7 @@
#include <cstring>
#include <string>
#include <algorithm>
#include <numeric>
#include "rabit/internal/utils.h"
#include "rabit/serializable.h"
@ -20,6 +21,10 @@ namespace utils {
typedef dmlc::SeekStream SeekStream;
/*! \brief fixed size memory buffer */
struct MemoryFixSizeBuffer : public SeekStream {
public:
// similar to SEEK_END in libc
static size_t constexpr SeekEnd = std::numeric_limits<size_t>::max();
public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
: p_buffer_(reinterpret_cast<char*>(p_buffer)),
@ -28,8 +33,6 @@ struct MemoryFixSizeBuffer : public SeekStream {
}
virtual ~MemoryFixSizeBuffer(void) {}
virtual size_t Read(void *ptr, size_t size) {
utils::Assert(curr_ptr_ + size <= buffer_size_,
"read can not have position excceed buffer length");
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
@ -43,7 +46,11 @@ struct MemoryFixSizeBuffer : public SeekStream {
curr_ptr_ += size;
}
virtual void Seek(size_t pos) {
curr_ptr_ = static_cast<size_t>(pos);
if (pos == SeekEnd) {
curr_ptr_ = buffer_size_;
} else {
curr_ptr_ = static_cast<size_t>(pos);
}
}
virtual size_t Tell(void) {
return curr_ptr_;

1
test/.gitignore vendored
View File

@ -1,4 +1,3 @@
*.mpi
test_*
*_test
*_recover

View File

@ -2,7 +2,10 @@ find_package(GTest REQUIRED)
add_executable(
unit_tests
allreduce_base_test.cc allreduce_robust_test.cc allreduce_mock_test.cc
test_io.cc
allreduce_robust_test.cc
allreduce_base_test.cc
allreduce_mock_test.cc
test_main.cpp)
target_link_libraries(

18
test/cpp/test_io.cc Normal file
View File

@ -0,0 +1,18 @@
/*!
* Copyright (c) 2019 by Contributors
*/
#include <gtest/gtest.h>
#include <rabit/internal/io.h>
#include <vector>
namespace rabit {
TEST(MemoryFixSizeBuffer, Seek) {
size_t constexpr kSize { 64 };
std::vector<int32_t> memory( kSize );
utils::MemoryFixSizeBuffer buf(memory.data(), memory.size());
buf.Seek(utils::MemoryFixSizeBuffer::SeekEnd);
size_t end = buf.Tell();
ASSERT_EQ(end, kSize);
}
} // namespace rabit