[LEARNER] refactor learner

This commit is contained in:
tqchen
2016-01-04 01:31:44 -08:00
parent 4b4b36d047
commit 0d95e863c9
14 changed files with 470 additions and 517 deletions

View File

@@ -9,12 +9,67 @@
#define XGBOOST_COMMON_IO_H_
#include <dmlc/io.h>
#include <string>
#include <cstring>
#include "./sync.h"
namespace xgboost {
namespace common {
typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
/*!
* \brief Input stream that support additional PeekRead
* operation, besides read.
*/
class PeekableInStream : public dmlc::Stream {
public:
explicit PeekableInStream(dmlc::Stream* strm)
: strm_(strm), buffer_ptr_(0) {}
size_t Read(void* dptr, size_t size) override {
size_t nbuffer = buffer_.length() - buffer_ptr_;
if (nbuffer == 0) return strm_->Read(dptr, size);
if (nbuffer < size) {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, nbuffer);
buffer_ptr_ += nbuffer;
return nbuffer + strm_->Read(reinterpret_cast<char*>(dptr) + nbuffer,
size - nbuffer);
} else {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
buffer_ptr_ += size;
return size;
}
}
size_t PeekRead(void* dptr, size_t size) {
size_t nbuffer = buffer_.length() - buffer_ptr_;
if (nbuffer < size) {
buffer_ = buffer_.substr(buffer_ptr_, buffer_.length());
buffer_ptr_ = 0;
buffer_.resize(size);
size_t nadd = strm_->Read(dmlc::BeginPtr(buffer_) + nbuffer, size - nbuffer);
buffer_.resize(nbuffer + nadd);
std::memcpy(dptr, dmlc::BeginPtr(buffer_), buffer_.length());
return buffer_.length();
} else {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
return size;
}
}
void Write(const void* dptr, size_t size) override {
LOG(FATAL) << "Not implemented";
}
private:
/*! \brief input stream */
dmlc::Stream *strm_;
/*! \brief current buffer pointer */
size_t buffer_ptr_;
/*! \brief internal buffer */
std::string buffer_;
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_IO_H_

View File

@@ -1,53 +0,0 @@
/*!
* Copyright 2015 by Contributors
* \file metric_set.h
* \brief additional math utils
* \author Tianqi Chen
*/
#ifndef XGBOOST_COMMON_METRIC_SET_H_
#define XGBOOST_COMMON_METRIC_SET_H_
#include <vector>
#include <string>
namespace xgboost {
namespace common {
/*! \brief helper util to create a set of metrics */
class MetricSet {
inline void AddEval(const char *name) {
using namespace std;
for (size_t i = 0; i < evals_.size(); ++i) {
if (!strcmp(name, evals_[i]->Name())) return;
}
evals_.push_back(CreateEvaluator(name));
}
~EvalSet(void) {
for (size_t i = 0; i < evals_.size(); ++i) {
delete evals_[i];
}
}
inline std::string Eval(const char *evname,
const std::vector<float> &preds,
const MetaInfo &info,
bool distributed = false) {
std::string result = "";
for (size_t i = 0; i < evals_.size(); ++i) {
float res = evals_[i]->Eval(preds, info, distributed);
char tmp[1024];
utils::SPrintf(tmp, sizeof(tmp), "\t%s-%s:%f", evname, evals_[i]->Name(), res);
result += tmp;
}
return result;
}
inline size_t Size(void) const {
return evals_.size();
}
private:
std::vector<const IEvaluator*> evals_;
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_METRIC_SET_H_