[LEARNER] refactor learner
This commit is contained in:
@@ -9,12 +9,67 @@
|
||||
#define XGBOOST_COMMON_IO_H_
|
||||
|
||||
#include <dmlc/io.h>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include "./sync.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
|
||||
typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
|
||||
|
||||
/*!
|
||||
* \brief Input stream that support additional PeekRead
|
||||
* operation, besides read.
|
||||
*/
|
||||
class PeekableInStream : public dmlc::Stream {
|
||||
public:
|
||||
explicit PeekableInStream(dmlc::Stream* strm)
|
||||
: strm_(strm), buffer_ptr_(0) {}
|
||||
|
||||
size_t Read(void* dptr, size_t size) override {
|
||||
size_t nbuffer = buffer_.length() - buffer_ptr_;
|
||||
if (nbuffer == 0) return strm_->Read(dptr, size);
|
||||
if (nbuffer < size) {
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, nbuffer);
|
||||
buffer_ptr_ += nbuffer;
|
||||
return nbuffer + strm_->Read(reinterpret_cast<char*>(dptr) + nbuffer,
|
||||
size - nbuffer);
|
||||
} else {
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
|
||||
buffer_ptr_ += size;
|
||||
return size;
|
||||
}
|
||||
}
|
||||
|
||||
size_t PeekRead(void* dptr, size_t size) {
|
||||
size_t nbuffer = buffer_.length() - buffer_ptr_;
|
||||
if (nbuffer < size) {
|
||||
buffer_ = buffer_.substr(buffer_ptr_, buffer_.length());
|
||||
buffer_ptr_ = 0;
|
||||
buffer_.resize(size);
|
||||
size_t nadd = strm_->Read(dmlc::BeginPtr(buffer_) + nbuffer, size - nbuffer);
|
||||
buffer_.resize(nbuffer + nadd);
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_), buffer_.length());
|
||||
return buffer_.length();
|
||||
} else {
|
||||
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
|
||||
return size;
|
||||
}
|
||||
}
|
||||
|
||||
void Write(const void* dptr, size_t size) override {
|
||||
LOG(FATAL) << "Not implemented";
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief input stream */
|
||||
dmlc::Stream *strm_;
|
||||
/*! \brief current buffer pointer */
|
||||
size_t buffer_ptr_;
|
||||
/*! \brief internal buffer */
|
||||
std::string buffer_;
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_COMMON_IO_H_
|
||||
|
||||
Reference in New Issue
Block a user