xgboost/old_src/io/dmlc_simple.cpp
2016-01-16 10:24:00 -08:00

230 lines
6.3 KiB
C++

// Copyright by Contributors
#define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX
#include <string>
#include "../utils/io.h"
// Implements a single-file, no-split version of the DMLC I/O interfaces,
// for builds that want to avoid a dependency on dmlc-core.
namespace xgboost {
namespace utils {
/*!
 * \brief InputSplit implementation backed by a single std::FILE.
 *
 *  The file (or stdin) is read in large chunks; every chunk handed out ends
 *  right after a line break, and the trailing partial line is carried over
 *  to the next chunk. Records are lines terminated by '\n' and/or '\r'.
 *  No partitioning of the input is performed.
 */
class SingleFileSplit : public dmlc::InputSplit {
 public:
  /*!
   * \brief open the input
   * \param fname path of the file to read; the literal name "stdin" selects
   *        standard input unless XGBOOST_STRICT_CXX98_ is defined
   */
  explicit SingleFileSplit(const char *fname)
      : fp_(NULL), use_stdin_(false),
        chunk_begin_(NULL), chunk_end_(NULL) {
    if (!std::strcmp(fname, "stdin")) {
#ifndef XGBOOST_STRICT_CXX98_
      use_stdin_ = true; fp_ = stdin;
#endif
    }
    if (!use_stdin_) {
      // NOTE(review): FopenCheck presumably reports an error itself when the
      // file cannot be opened -- confirm in utils/io.h
      fp_ = utils::FopenCheck(fname, "rb");
    }
    buffer_.resize(kBufferSize);
  }
  /*! \brief closes the file unless it is stdin */
  virtual ~SingleFileSplit(void) {
    if (!use_stdin_) std::fclose(fp_);
  }
  /*!
   * \brief raw read from the underlying file
   * \param ptr destination buffer
   * \param size number of bytes requested
   * \return number of bytes actually read
   */
  virtual size_t Read(void *ptr, size_t size) {
    return std::fread(ptr, 1, size, fp_);
  }
  /*! \brief writing is not supported; always reports an error */
  virtual void Write(const void *ptr, size_t size) {
    utils::Error("cannot do write in inputsplit");
  }
  /*!
   * \brief rewind to the beginning of the input
   *
   *  When the rewind succeeds, the chunk that is currently buffered and the
   *  carried-over partial line are dropped as well; otherwise the next read
   *  would return data from the old position. When the stream cannot be
   *  repositioned (fseek fails), the buffered state is left untouched so
   *  that no unread data is lost.
   */
  virtual void BeforeFirst(void) {
    if (std::fseek(fp_, 0, SEEK_SET) == 0) {
      chunk_begin_ = NULL;
      chunk_end_ = NULL;
      overflow_.resize(0);
    }
  }
  /*!
   * \brief fetch the next record (line) of the input
   * \param out_rec receives pointer and size of the record; the size spans up
   *        to the start of the following record, so it includes the trailing
   *        line-break characters
   * \return false when no more data is available
   */
  virtual bool NextRecord(Blob *out_rec) {
    if (chunk_begin_ == chunk_end_) {
      if (!LoadChunk()) return false;
    }
    char *next = FindNextRecord(chunk_begin_,
                                chunk_end_);
    out_rec->dptr = chunk_begin_;
    out_rec->size = next - chunk_begin_;
    chunk_begin_ = next;
    return true;
  }
  /*!
   * \brief fetch everything that is left of the current chunk
   * \param out_chunk receives pointer and size of the remaining chunk data
   * \return false when no more data is available
   */
  virtual bool NextChunk(Blob *out_chunk) {
    if (chunk_begin_ == chunk_end_) {
      if (!LoadChunk()) return false;
    }
    out_chunk->dptr = chunk_begin_;
    out_chunk->size = chunk_end_ - chunk_begin_;
    chunk_begin_ = chunk_end_;
    return true;
  }
  /*!
   * \brief fill buf with data that ends on a record boundary
   * \param buf destination buffer
   * \param size in: capacity of buf in bytes; out: number of valid bytes.
   *        Set to 0 (with a true return value) when the buffer is too small
   *        to make progress, i.e. the caller should retry with a larger one.
   * \return false only when there is no data left at all
   */
  inline bool ReadChunk(void *buf, size_t *size) {
    size_t max_size = *size;
    // the carried-over partial line alone would fill the buffer
    if (max_size <= overflow_.length()) {
      *size = 0; return true;
    }
    // start the buffer with the partial line left from the previous call
    if (overflow_.length() != 0) {
      std::memcpy(buf, BeginPtr(overflow_), overflow_.length());
    }
    size_t olen = overflow_.length();
    overflow_.resize(0);
    size_t nread = this->Read(reinterpret_cast<char*>(buf) + olen,
                              max_size - olen);
    nread += olen;
    if (nread == 0) return false;
    if (nread != max_size) {
      // short read: hand out everything that was obtained
      *size = nread;
      return true;
    } else {
      const char *bptr = reinterpret_cast<const char*>(buf);
      // return the last position where a record starts
      const char *bend = this->FindLastRecordBegin(bptr, bptr + max_size);
      *size = bend - bptr;
      // keep the trailing partial record for the next call
      overflow_.resize(max_size - *size);
      if (overflow_.length() != 0) {
        std::memcpy(BeginPtr(overflow_), bend, overflow_.length());
      }
      return true;
    }
  }

 protected:
  /*!
   * \brief locate the start of the last record in [begin, end)
   * \return pointer right after the last '\n' or '\r' found while scanning
   *         backwards from end - 1 down to begin + 1, or begin if none
   */
  inline const char* FindLastRecordBegin(const char *begin,
                                         const char *end) {
    if (begin == end) return begin;
    for (const char *p = end - 1; p != begin; --p) {
      if (*p == '\n' || *p == '\r') return p + 1;
    }
    return begin;
  }
  /*!
   * \brief locate the start of the record following the one at begin
   * \return pointer to the first non line-break character after the first
   *         run of line breaks in [begin, end), or end if there is none
   */
  inline char* FindNextRecord(char *begin, char *end) {
    char *p;
    // skip the body of the current record
    for (p = begin; p != end; ++p) {
      if (*p == '\n' || *p == '\r') break;
    }
    // skip the line-break characters that terminate it
    for (; p != end; ++p) {
      if (*p != '\n' && *p != '\r') return p;
    }
    return end;
  }
  /*!
   * \brief load the next chunk into the internal buffer
   *
   *  The buffer is doubled whenever ReadChunk reports that it is too small.
   * \return false when no more data is available
   */
  inline bool LoadChunk(void) {
    while (true) {
      size_t size = buffer_.length();
      if (!ReadChunk(BeginPtr(buffer_), &size)) return false;
      if (size == 0) {
        buffer_.resize(buffer_.length() * 2);
      } else {
        chunk_begin_ = reinterpret_cast<char *>(BeginPtr(buffer_));
        chunk_end_ = chunk_begin_ + size;
        break;
      }
    }
    return true;
  }

 private:
  // initial size of the internal buffer, in bytes
  static const size_t kBufferSize = 1 << 18UL;
  // underlying file handle
  std::FILE *fp_;
  // whether fp_ is stdin (and therefore must not be closed)
  bool use_stdin_;
  // partial record carried over between two ReadChunk calls
  std::string overflow_;
  // internal buffer holding the current chunk
  std::string buffer_;
  // beginning of the unread part of the current chunk
  char *chunk_begin_;
  // end of the current chunk
  char *chunk_end_;
};
class StdFile : public dmlc::Stream {
public:
explicit StdFile(std::FILE *fp, bool use_stdio)
: fp(fp), use_stdio(use_stdio) {
}
virtual ~StdFile(void) {
this->Close();
}
virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, 1, size, fp);
}
virtual void Write(const void *ptr, size_t size) {
Check(std::fwrite(ptr, size, 1, fp) == 1, "StdFile::Write: fwrite error!");
}
virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET); // NOLINT(*)
}
virtual size_t Tell(void) {
return std::ftell(fp);
}
virtual bool AtEnd(void) const {
return std::feof(fp) != 0;
}
inline void Close(void) {
if (fp != NULL && !use_stdio) {
std::fclose(fp); fp = NULL;
}
}
private:
std::FILE *fp;
bool use_stdio;
};
} // namespace utils
} // namespace xgboost
namespace dmlc {
/*!
 * \brief create an input split for local mode
 *
 *  Only a single, unpartitioned local file (or "stdin") is supported in this
 *  build: s3:// and hdfs:// URIs as well as nsplit != 1 are rejected with an
 *  error. A leading "file://" is stripped, matching Stream::Create.
 *
 * \param uri path of the input file, optionally prefixed with "file://"
 * \param part index of the requested part; not used in local mode
 * \param nsplit total number of parts, must be 1
 * \param type type of the split; not used in local mode
 * \return a newly allocated split, owned by the caller
 */
InputSplit* InputSplit::Create(const char *uri,
                               unsigned part,
                               unsigned nsplit,
                               const char *type) {
  using namespace std;
  using namespace xgboost;
  const char *msg = "xgboost is compiled in local mode\n"\
      "to use hdfs, s3 or distributed version, compile with make dmlc=1";
  utils::Check(strncmp(uri, "s3://", 5) != 0, msg);
  utils::Check(strncmp(uri, "hdfs://", 7) != 0, msg);
  utils::Check(nsplit == 1, msg);
  // accept the same "file://" scheme prefix that Stream::Create accepts;
  // without this the prefix would be passed to fopen as part of the path
  if (!strncmp(uri, "file://", 7)) uri += 7;
  return new utils::SingleFileSplit(uri);
}
/*!
 * \brief open a stream on a local file or a standard stream
 *
 *  s3:// and hdfs:// names are rejected with an error in this local-mode
 *  build. Unless XGBOOST_STRICT_CXX98_ is defined, the names "stdin" and
 *  "stdout" select the corresponding standard stream. A leading "file://"
 *  is stripped, and the modes "r" / "w" are opened in binary form.
 *
 * \param fname name of the file to open
 * \param mode fopen-style mode string
 * \param allow_null when true, a failed open yields NULL instead of an error
 * \return a newly allocated stream owned by the caller, or NULL
 */
Stream *Stream::Create(const char *fname, const char * const mode, bool allow_null) {
  using namespace std;
  using namespace xgboost;
  const char *msg = "xgboost is compiled in local mode\n"\
      "to use hdfs, s3 or distributed version, compile with make dmlc=1";
  utils::Check(strncmp(fname, "s3://", 5) != 0, msg);
  utils::Check(strncmp(fname, "hdfs://", 7) != 0, msg);

  std::FILE *handle = NULL;
  bool is_std_stream = false;
#ifndef XGBOOST_STRICT_CXX98_
  // the two names are mutually exclusive, so at most one branch fires
  if (strcmp(fname, "stdin") == 0) {
    handle = stdin;
    is_std_stream = true;
  } else if (strcmp(fname, "stdout") == 0) {
    handle = stdout;
    is_std_stream = true;
  }
#endif
  // drop the scheme prefix so that only the plain path remains
  if (strncmp(fname, "file://", 7) == 0) {
    fname += 7;
  }
  if (!is_std_stream) {
    // plain "w" / "r" are promoted to binary mode, anything else is used as is
    std::string open_mode = mode;
    if (open_mode == "w") {
      open_mode = "wb";
    } else if (open_mode == "r") {
      open_mode = "rb";
    }
    handle = fopen64(fname, open_mode.c_str());
  }
  if (handle == NULL) {
    utils::Check(allow_null, "fail to open file %s", fname);
    return NULL;
  }
  return new utils::StdFile(handle, is_std_stream);
}
} // namespace dmlc