This commit is contained in:
tqchen 2015-07-03 19:35:23 -07:00
parent aba41d07cd
commit 1123253f79
10 changed files with 178 additions and 143 deletions

View File

@ -1,10 +1,12 @@
#ifndef XGBOOST_DATA_H
#define XGBOOST_DATA_H
/*! /*!
* Copyright (c) 2014 by Contributors
* \file data.h * \file data.h
* \brief the input data structure for gradient boosting * \brief the input data structure for gradient boosting
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_DATA_H_
#define XGBOOST_DATA_H_
#include <cstdio> #include <cstdio>
#include <vector> #include <vector>
#include "utils/utils.h" #include "utils/utils.h"
@ -32,7 +34,7 @@ struct bst_gpair {
bst_gpair(bst_float grad, bst_float hess) : grad(grad), hess(hess) {} bst_gpair(bst_float grad, bst_float hess) : grad(grad), hess(hess) {}
}; };
/*! /*!
* \brief extra information that might be needed by gbm and tree module * \brief extra information that might be needed by gbm and tree module
* this information is not necessarily present, and can be empty * this information is not necessarily present, and can be empty
*/ */
@ -102,7 +104,7 @@ struct RowBatch : public SparseBatch {
return Inst(data_ptr + ind_ptr[i], static_cast<bst_uint>(ind_ptr[i+1] - ind_ptr[i])); return Inst(data_ptr + ind_ptr[i], static_cast<bst_uint>(ind_ptr[i+1] - ind_ptr[i]));
} }
}; };
/*! /*!
* \brief read-only column batch, used to access columns, * \brief read-only column batch, used to access columns,
* the columns are not required to be continuous * the columns are not required to be continuous
*/ */
@ -131,7 +133,7 @@ class IFMatrix {
/*!\brief get column iterator */ /*!\brief get column iterator */
virtual utils::IIterator<ColBatch> *ColIterator(void) = 0; virtual utils::IIterator<ColBatch> *ColIterator(void) = 0;
/*! /*!
* \brief get the column iterator associated with FMatrix with subset of column features * \brief get the column iterator associated with FMatrix with subset of column features
* \param fset is the list of column index set that must be contained in the returning Column iterator * \param fset is the list of column index set that must be contained in the returning Column iterator
* \return the column iterator, initialized so that it reads the elements in fset * \return the column iterator, initialized so that it reads the elements in fset
*/ */
@ -154,11 +156,11 @@ class IFMatrix {
/*! \brief get number of non-missing entries in column */ /*! \brief get number of non-missing entries in column */
virtual size_t GetColSize(size_t cidx) const = 0; virtual size_t GetColSize(size_t cidx) const = 0;
/*! \brief get column density */ /*! \brief get column density */
virtual float GetColDensity(size_t cidx) const = 0; virtual float GetColDensity(size_t cidx) const = 0;
/*! \brief reference of buffered rowset */ /*! \brief reference of buffered rowset */
virtual const std::vector<bst_uint> &buffered_rowset(void) const = 0; virtual const std::vector<bst_uint> &buffered_rowset(void) const = 0;
// virtual destructor // virtual destructor
virtual ~IFMatrix(void){} virtual ~IFMatrix(void){}
}; };
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_DATA_H #endif // XGBOOST_DATA_H_

View File

@ -1,6 +1,8 @@
// Copyright by Contributors
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE #define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX #define NOMINMAX
#include <string>
#include "../utils/io.h" #include "../utils/io.h"
// implements a single no split version of DMLC // implements a single no split version of DMLC
@ -9,7 +11,7 @@
namespace xgboost { namespace xgboost {
namespace utils { namespace utils {
/*! /*!
* \brief line split implementation from single FILE * \brief line split implementation from single FILE
* simply returns lines of files, used for stdin * simply returns lines of files, used for stdin
*/ */
class SingleFileSplit : public dmlc::InputSplit { class SingleFileSplit : public dmlc::InputSplit {
@ -32,7 +34,7 @@ class SingleFileSplit : public dmlc::InputSplit {
} }
virtual size_t Read(void *ptr, size_t size) { virtual size_t Read(void *ptr, size_t size) {
return std::fread(ptr, 1, size, fp_); return std::fread(ptr, 1, size, fp_);
} }
virtual void Write(const void *ptr, size_t size) { virtual void Write(const void *ptr, size_t size) {
utils::Error("cannot do write in inputsplit"); utils::Error("cannot do write in inputsplit");
} }
@ -47,13 +49,13 @@ class SingleFileSplit : public dmlc::InputSplit {
chunk_end_); chunk_end_);
out_rec->dptr = chunk_begin_; out_rec->dptr = chunk_begin_;
out_rec->size = next - chunk_begin_; out_rec->size = next - chunk_begin_;
chunk_begin_ = next; chunk_begin_ = next;
return true; return true;
} }
virtual bool NextChunk(Blob *out_chunk) { virtual bool NextChunk(Blob *out_chunk) {
if (chunk_begin_ == chunk_end_) { if (chunk_begin_ == chunk_end_) {
if (!LoadChunk()) return false; if (!LoadChunk()) return false;
} }
out_chunk->dptr = chunk_begin_; out_chunk->dptr = chunk_begin_;
out_chunk->size = chunk_end_ - chunk_begin_; out_chunk->size = chunk_end_ - chunk_begin_;
chunk_begin_ = chunk_end_; chunk_begin_ = chunk_end_;
@ -64,8 +66,8 @@ class SingleFileSplit : public dmlc::InputSplit {
if (max_size <= overflow_.length()) { if (max_size <= overflow_.length()) {
*size = 0; return true; *size = 0; return true;
} }
if (overflow_.length() != 0) { if (overflow_.length() != 0) {
std::memcpy(buf, BeginPtr(overflow_), overflow_.length()); std::memcpy(buf, BeginPtr(overflow_), overflow_.length());
} }
size_t olen = overflow_.length(); size_t olen = overflow_.length();
overflow_.resize(0); overflow_.resize(0);
@ -88,13 +90,13 @@ class SingleFileSplit : public dmlc::InputSplit {
return true; return true;
} }
} }
protected: protected:
inline const char* FindLastRecordBegin(const char *begin, inline const char* FindLastRecordBegin(const char *begin,
const char *end) { const char *end) {
if (begin == end) return begin; if (begin == end) return begin;
for (const char *p = end - 1; p != begin; --p) { for (const char *p = end - 1; p != begin; --p) {
if (*p == '\n' || *p == '\r') return p + 1; if (*p == '\n' || *p == '\r') return p + 1;
} }
return begin; return begin;
} }
@ -143,7 +145,7 @@ class StdFile : public dmlc::Stream {
public: public:
explicit StdFile(std::FILE *fp, bool use_stdio) explicit StdFile(std::FILE *fp, bool use_stdio)
: fp(fp), use_stdio(use_stdio) { : fp(fp), use_stdio(use_stdio) {
} }
virtual ~StdFile(void) { virtual ~StdFile(void) {
this->Close(); this->Close();
} }
@ -154,7 +156,7 @@ class StdFile : public dmlc::Stream {
std::fwrite(ptr, size, 1, fp); std::fwrite(ptr, size, 1, fp);
} }
virtual void Seek(size_t pos) { virtual void Seek(size_t pos) {
std::fseek(fp, static_cast<long>(pos), SEEK_SET); std::fseek(fp, static_cast<long>(pos), SEEK_SET); // NOLINT(*)
} }
virtual size_t Tell(void) { virtual size_t Tell(void) {
return std::ftell(fp); return std::ftell(fp);
@ -197,7 +199,7 @@ Stream *Stream::Create(const char *fname, const char * const mode, bool allow_nu
"to use hdfs, s3 or distributed version, compile with make dmlc=1"; "to use hdfs, s3 or distributed version, compile with make dmlc=1";
utils::Check(strncmp(fname, "s3://", 5) != 0, msg); utils::Check(strncmp(fname, "s3://", 5) != 0, msg);
utils::Check(strncmp(fname, "hdfs://", 7) != 0, msg); utils::Check(strncmp(fname, "hdfs://", 7) != 0, msg);
std::FILE *fp = NULL; std::FILE *fp = NULL;
bool use_stdio = false; bool use_stdio = false;
using namespace std; using namespace std;

View File

@ -1,3 +1,4 @@
// Copyright 2014 by Contributors
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE #define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX #define NOMINMAX
@ -17,7 +18,7 @@ DataMatrix* LoadDataMatrix(const char *fname,
const char *cache_file) { const char *cache_file) {
using namespace std; using namespace std;
std::string fname_ = fname; std::string fname_ = fname;
const char *dlm = strchr(fname, '#'); const char *dlm = strchr(fname, '#');
if (dlm != NULL) { if (dlm != NULL) {
utils::Check(strchr(dlm + 1, '#') == NULL, utils::Check(strchr(dlm + 1, '#') == NULL,
@ -29,7 +30,7 @@ DataMatrix* LoadDataMatrix(const char *fname,
cache_file = dlm +1; cache_file = dlm +1;
} }
if (cache_file == NULL) { if (cache_file == NULL) {
if (!std::strcmp(fname, "stdin") || if (!std::strcmp(fname, "stdin") ||
!std::strncmp(fname, "s3://", 5) || !std::strncmp(fname, "s3://", 5) ||
!std::strncmp(fname, "hdfs://", 7) || !std::strncmp(fname, "hdfs://", 7) ||
@ -42,7 +43,7 @@ DataMatrix* LoadDataMatrix(const char *fname,
utils::FileStream fs(utils::FopenCheck(fname, "rb")); utils::FileStream fs(utils::FopenCheck(fname, "rb"));
utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format"); utils::Check(fs.Read(&magic, sizeof(magic)) != 0, "invalid input file format");
fs.Seek(0); fs.Seek(0);
if (magic == DMatrixSimple::kMagic) { if (magic == DMatrixSimple::kMagic) {
DMatrixSimple *dmat = new DMatrixSimple(); DMatrixSimple *dmat = new DMatrixSimple();
dmat->LoadBinary(fs, silent, fname); dmat->LoadBinary(fs, silent, fname);
fs.Close(); fs.Close();
@ -81,7 +82,7 @@ DataMatrix* LoadDataMatrix(const char *fname,
} }
} }
void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) { void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) {
if (dmat.magic == DMatrixSimple::kMagic) { if (dmat.magic == DMatrixSimple::kMagic) {
const DMatrixSimple *p_dmat = static_cast<const DMatrixSimple*>(&dmat); const DMatrixSimple *p_dmat = static_cast<const DMatrixSimple*>(&dmat);
p_dmat->SaveBinary(fname, silent); p_dmat->SaveBinary(fname, silent);

View File

@ -22,7 +22,7 @@ namespace io {
/*! \brief page returned by libsvm parser */ /*! \brief page returned by libsvm parser */
struct LibSVMPage : public SparsePage { struct LibSVMPage : public SparsePage {
std::vector<float> label; std::vector<float> label;
// overload clear // overload clear
inline void Clear() { inline void Clear() {
SparsePage::Clear(); SparsePage::Clear();
label.clear(); label.clear();
@ -35,7 +35,7 @@ struct LibSVMPage : public SparsePage {
*/ */
class LibSVMPageFactory { class LibSVMPageFactory {
public: public:
explicit LibSVMPageFactory() LibSVMPageFactory()
: bytes_read_(0), at_head_(true) { : bytes_read_(0), at_head_(true) {
} }
inline bool Init(void) { inline bool Init(void) {
@ -85,7 +85,7 @@ class LibSVMPageFactory {
data->resize(nthread); data->resize(nthread);
bytes_read_ += chunk.size; bytes_read_ += chunk.size;
utils::Assert(chunk.size != 0, "LibSVMParser.FileData"); utils::Assert(chunk.size != 0, "LibSVMParser.FileData");
char *head = reinterpret_cast<char*>(chunk.dptr); char *head = reinterpret_cast<char*>(chunk.dptr);
#pragma omp parallel num_threads(nthread_) #pragma omp parallel num_threads(nthread_)
{ {
// threadid // threadid
@ -150,7 +150,7 @@ class LibSVMPageFactory {
} }
return begin; return begin;
} }
private: private:
// nthread // nthread
int nthread_; int nthread_;
@ -199,12 +199,13 @@ class LibSVMParser : public utils::IIterator<LibSVMPage> {
inline size_t bytes_read(void) const { inline size_t bytes_read(void) const {
return itr.get_factory().bytes_read(); return itr.get_factory().bytes_read();
} }
private: private:
bool at_end_; bool at_end_;
size_t data_ptr_; size_t data_ptr_;
std::vector<LibSVMPage> *data_; std::vector<LibSVMPage> *data_;
utils::ThreadBuffer<std::vector<LibSVMPage>*, LibSVMPageFactory> itr; utils::ThreadBuffer<std::vector<LibSVMPage>*, LibSVMPageFactory> itr;
}; };
} // namespace io } // namespace io
} // namespace xgboost } // namespace xgboost

View File

@ -1,11 +1,15 @@
#ifndef XGBOOST_IO_PAGE_DMATRIX_INL_HPP_
#define XGBOOST_IO_PAGE_DMATRIX_INL_HPP_
/*! /*!
* Copyright (c) 2014 by Contributors
* \file page_dmatrix-inl.hpp * \file page_dmatrix-inl.hpp
* row iterator based on sparse page * row iterator based on sparse page
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_IO_PAGE_DMATRIX_INL_HPP_
#define XGBOOST_IO_PAGE_DMATRIX_INL_HPP_
#include <vector> #include <vector>
#include <string>
#include <algorithm>
#include "../data.h" #include "../data.h"
#include "../utils/iterator.h" #include "../utils/iterator.h"
#include "../utils/thread_buffer.h" #include "../utils/thread_buffer.h"
@ -94,12 +98,12 @@ class DMatrixPageBase : public DataMatrix {
fbin.Close(); fbin.Close();
if (!silent) { if (!silent) {
utils::Printf("DMatrixPage: %lux%lu is saved to %s\n", utils::Printf("DMatrixPage: %lux%lu is saved to %s\n",
static_cast<unsigned long>(mat.info.num_row()), static_cast<unsigned long>(mat.info.num_row()), // NOLINT(*)
static_cast<unsigned long>(mat.info.num_col()), fname_); static_cast<unsigned long>(mat.info.num_col()), fname_); // NOLINT(*)
} }
} }
/*! \brief load and initialize the iterator with fi */ /*! \brief load and initialize the iterator with fi */
inline void LoadBinary(utils::FileStream &fi, inline void LoadBinary(utils::FileStream &fi, // NOLINT(*)
bool silent, bool silent,
const char *fname_) { const char *fname_) {
this->set_cache_file(fname_); this->set_cache_file(fname_);
@ -114,8 +118,8 @@ class DMatrixPageBase : public DataMatrix {
iter_->Load(fs); iter_->Load(fs);
if (!silent) { if (!silent) {
utils::Printf("DMatrixPage: %lux%lu matrix is loaded", utils::Printf("DMatrixPage: %lux%lu matrix is loaded",
static_cast<unsigned long>(info.num_row()), static_cast<unsigned long>(info.num_row()), // NOLINT(*)
static_cast<unsigned long>(info.num_col())); static_cast<unsigned long>(info.num_col())); // NOLINT(*)
if (fname_ != NULL) { if (fname_ != NULL) {
utils::Printf(" from %s\n", fname_); utils::Printf(" from %s\n", fname_);
} else { } else {
@ -141,7 +145,7 @@ class DMatrixPageBase : public DataMatrix {
} }
this->set_cache_file(cache_file); this->set_cache_file(cache_file);
std::string fname_row = std::string(cache_file) + ".row.blob"; std::string fname_row = std::string(cache_file) + ".row.blob";
utils::FileStream fo(utils::FopenCheck(fname_row.c_str(), "wb")); utils::FileStream fo(utils::FopenCheck(fname_row.c_str(), "wb"));
SparsePage page; SparsePage page;
size_t bytes_write = 0; size_t bytes_write = 0;
double tstart = rabit::utils::GetTime(); double tstart = rabit::utils::GetTime();
@ -178,8 +182,8 @@ class DMatrixPageBase : public DataMatrix {
if (page.data.size() != 0) { if (page.data.size() != 0) {
page.Save(&fo); page.Save(&fo);
} }
fo.Close(); fo.Close();
iter_->Load(utils::FileStream(utils::FopenCheck(fname_row.c_str(), "rb"))); iter_->Load(utils::FileStream(utils::FopenCheck(fname_row.c_str(), "rb")));
// save data matrix // save data matrix
utils::FileStream fs(utils::FopenCheck(cache_file, "wb")); utils::FileStream fs(utils::FopenCheck(cache_file, "wb"));
int tmagic = kMagic; int tmagic = kMagic;
@ -188,8 +192,8 @@ class DMatrixPageBase : public DataMatrix {
fs.Close(); fs.Close();
if (!silent) { if (!silent) {
utils::Printf("DMatrixPage: %lux%lu is parsed from %s\n", utils::Printf("DMatrixPage: %lux%lu is parsed from %s\n",
static_cast<unsigned long>(info.num_row()), static_cast<unsigned long>(info.num_row()), // NOLINT(*)
static_cast<unsigned long>(info.num_col()), static_cast<unsigned long>(info.num_col()), // NOLINT(*)
uri); uri);
} }
} }
@ -241,12 +245,12 @@ class DMatrixHalfRAM : public DMatrixPageBase<0xffffab03> {
virtual IFMatrix *fmat(void) const { virtual IFMatrix *fmat(void) const {
return fmat_; return fmat_;
} }
virtual void set_cache_file(const std::string &cache_file) { virtual void set_cache_file(const std::string &cache_file) {
} }
virtual void CheckMagic(int tmagic) { virtual void CheckMagic(int tmagic) {
utils::Check(tmagic == DMatrixPageBase<0xffffab02>::kMagic || utils::Check(tmagic == DMatrixPageBase<0xffffab02>::kMagic ||
tmagic == DMatrixPageBase<0xffffab03>::kMagic, tmagic == DMatrixPageBase<0xffffab03>::kMagic,
"invalid format,magic number mismatch"); "invalid format,magic number mismatch");
} }
/*! \brief the real fmatrix */ /*! \brief the real fmatrix */
IFMatrix *fmat_; IFMatrix *fmat_;

View File

@ -1,10 +1,16 @@
#ifndef XGBOOST_IO_PAGE_FMATRIX_INL_HPP_
#define XGBOOST_IO_PAGE_FMATRIX_INL_HPP_
/*! /*!
* Copyright (c) 2014 by Contributors
* \file page_fmatrix-inl.hpp * \file page_fmatrix-inl.hpp
* col iterator based on sparse page * col iterator based on sparse page
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_IO_PAGE_FMATRIX_INL_HPP_
#define XGBOOST_IO_PAGE_FMATRIX_INL_HPP_
#include <vector>
#include <string>
#include <algorithm>
namespace xgboost { namespace xgboost {
namespace io { namespace io {
/*! \brief thread buffer iterator */ /*! \brief thread buffer iterator */
@ -42,9 +48,9 @@ class ThreadColPageIterator: public utils::IIterator<ColBatch> {
} }
// set index set // set index set
inline void SetIndexSet(const std::vector<bst_uint> &fset, bool load_all) { inline void SetIndexSet(const std::vector<bst_uint> &fset, bool load_all) {
itr.get_factory().SetIndexSet(fset, load_all); itr.get_factory().SetIndexSet(fset, load_all);
} }
private: private:
// output data // output data
ColBatch out_; ColBatch out_;
@ -96,7 +102,7 @@ struct ColConvertFactory {
return true; return true;
} }
} }
if (tmp_.Size() != 0){ if (tmp_.Size() != 0) {
this->MakeColPage(tmp_, BeginPtr(*buffered_rowset_) + btop, this->MakeColPage(tmp_, BeginPtr(*buffered_rowset_) + btop,
*enabled_, val); *enabled_, val);
return true; return true;
@ -104,7 +110,7 @@ struct ColConvertFactory {
return false; return false;
} }
} }
inline void Destroy(void) {} inline void Destroy(void) {}
inline void BeforeFirst(void) {} inline void BeforeFirst(void) {}
inline void MakeColPage(const SparsePage &prow, inline void MakeColPage(const SparsePage &prow,
const bst_uint *ridx, const bst_uint *ridx,
@ -115,7 +121,7 @@ struct ColConvertFactory {
#pragma omp parallel #pragma omp parallel
{ {
nthread = omp_get_num_threads(); nthread = omp_get_num_threads();
int max_nthread = std::max(omp_get_num_procs() / 2 - 4, 1); int max_nthread = std::max(omp_get_num_procs() / 2 - 4, 1);
if (nthread > max_nthread) { if (nthread > max_nthread) {
nthread = max_nthread; nthread = max_nthread;
} }
@ -130,10 +136,10 @@ struct ColConvertFactory {
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) { for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
const SparseBatch::Entry &e = prow.data[j]; const SparseBatch::Entry &e = prow.data[j];
if (enabled[e.index]) { if (enabled[e.index]) {
builder.AddBudget(e.index, tid); builder.AddBudget(e.index, tid);
} }
} }
} }
builder.InitStorage(); builder.InitStorage();
#pragma omp parallel for schedule(static) num_threads(nthread) #pragma omp parallel for schedule(static) num_threads(nthread)
@ -169,7 +175,7 @@ struct ColConvertFactory {
// buffered rowset // buffered rowset
std::vector<bst_uint> *buffered_rowset_; std::vector<bst_uint> *buffered_rowset_;
// enabled marks // enabled marks
const std::vector<bool> *enabled_; const std::vector<bool> *enabled_;
// internal temp cache // internal temp cache
SparsePage tmp_; SparsePage tmp_;
/*! \brief page size 256 M */ /*! \brief page size 256 M */
@ -191,7 +197,7 @@ class FMatrixPage : public IFMatrix {
if (iter_ != NULL) delete iter_; if (iter_ != NULL) delete iter_;
} }
/*! \return whether column access is enabled */ /*! \return whether column access is enabled */
virtual bool HaveColAccess(void) const { virtual bool HaveColAccess(void) const {
return col_size_.size() != 0; return col_size_.size() != 0;
} }
/*! \brief get number of columns */ /*! \brief get number of columns */
@ -212,7 +218,7 @@ class FMatrixPage : public IFMatrix {
size_t nmiss = num_buffered_row_ - (col_size_[cidx]); size_t nmiss = num_buffered_row_ - (col_size_[cidx]);
return 1.0f - (static_cast<float>(nmiss)) / num_buffered_row_; return 1.0f - (static_cast<float>(nmiss)) / num_buffered_row_;
} }
virtual void InitColAccess(const std::vector<bool> &enabled, virtual void InitColAccess(const std::vector<bool> &enabled,
float pkeep, size_t max_row_perbatch) { float pkeep, size_t max_row_perbatch) {
if (this->HaveColAccess()) return; if (this->HaveColAccess()) return;
if (TryLoadColData()) return; if (TryLoadColData()) return;
@ -242,11 +248,11 @@ class FMatrixPage : public IFMatrix {
/*! /*!
* \brief column based iterator * \brief column based iterator
*/ */
virtual utils::IIterator<ColBatch> *ColIterator(const std::vector<bst_uint> &fset) { virtual utils::IIterator<ColBatch> *ColIterator(const std::vector<bst_uint> &fset) {
size_t ncol = this->NumCol(); size_t ncol = this->NumCol();
col_index_.resize(0); col_index_.resize(0);
for (size_t i = 0; i < fset.size(); ++i) { for (size_t i = 0; i < fset.size(); ++i) {
if (fset[i] < ncol) col_index_.push_back(fset[i]); if (fset[i] < ncol) col_index_.push_back(fset[i]);
} }
col_iter_.SetIndexSet(col_index_, false); col_iter_.SetIndexSet(col_index_, false);
col_iter_.BeforeFirst(); col_iter_.BeforeFirst();
@ -255,13 +261,13 @@ class FMatrixPage : public IFMatrix {
// set the cache file name // set the cache file name
inline void set_cache_file(const std::string &cache_file) { inline void set_cache_file(const std::string &cache_file) {
col_data_name_ = std::string(cache_file) + ".col.blob"; col_data_name_ = std::string(cache_file) + ".col.blob";
col_meta_name_ = std::string(cache_file) + ".col.meta"; col_meta_name_ = std::string(cache_file) + ".col.meta";
} }
protected: protected:
inline bool TryLoadColData(void) { inline bool TryLoadColData(void) {
std::FILE *fi = fopen64(col_meta_name_.c_str(), "rb"); std::FILE *fi = fopen64(col_meta_name_.c_str(), "rb");
if (fi == NULL) return false; if (fi == NULL) return false;
utils::FileStream fs(fi); utils::FileStream fs(fi);
LoadMeta(&fs); LoadMeta(&fs);
fs.Close(); fs.Close();
@ -306,12 +312,12 @@ class FMatrixPage : public IFMatrix {
SparsePage *pcol; SparsePage *pcol;
while (citer.Next(pcol)) { while (citer.Next(pcol)) {
for (size_t i = 0; i < pcol->Size(); ++i) { for (size_t i = 0; i < pcol->Size(); ++i) {
col_size_[i] += pcol->offset[i + 1] - pcol->offset[i]; col_size_[i] += pcol->offset[i + 1] - pcol->offset[i];
} }
pcol->Save(&fo); pcol->Save(&fo);
size_t spage = pcol->MemCostBytes(); size_t spage = pcol->MemCostBytes();
bytes_write += spage; bytes_write += spage;
double tnow = rabit::utils::GetTime(); double tnow = rabit::utils::GetTime();
double tdiff = tnow - tstart; double tdiff = tnow - tstart;
utils::Printf("Writting to %s in %g MB/s, %lu MB written current speed:%g MB/s\n", utils::Printf("Writting to %s in %g MB/s, %lu MB written current speed:%g MB/s\n",
col_data_name_.c_str(), col_data_name_.c_str(),

View File

@ -1,13 +1,15 @@
#ifndef XGBOOST_IO_SIMPLE_DMATRIX_INL_HPP_
#define XGBOOST_IO_SIMPLE_DMATRIX_INL_HPP_
/*! /*!
* Copyright 2014 by Contributors
* \file simple_dmatrix-inl.hpp * \file simple_dmatrix-inl.hpp
* \brief simple implementation of DMatrixS that can be used * \brief simple implementation of DMatrixS that can be used
* the data format of xgboost is templatized, which means it can accept * the data format of xgboost is templatized, which means it can accept
* any data structure that implements the function defined by FMatrix * any data structure that implements the function defined by FMatrix
* this file is a specific implementation of input data structure that can be used by BoostLearner * this file is a specific implementation of input data structure that can be used by BoostLearner
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_IO_SIMPLE_DMATRIX_INL_HPP_
#define XGBOOST_IO_SIMPLE_DMATRIX_INL_HPP_
#include <string> #include <string>
#include <cstring> #include <cstring>
#include <vector> #include <vector>
@ -119,13 +121,13 @@ class DMatrixSimple : public DataMatrix {
for (size_t i = 0; i < batch.data.size(); ++i) { for (size_t i = 0; i < batch.data.size(); ++i) {
info.info.num_col = std::max(info.info.num_col, info.info.num_col = std::max(info.info.num_col,
static_cast<size_t>(batch.data[i].index+1)); static_cast<size_t>(batch.data[i].index+1));
} }
} }
if (!silent) { if (!silent) {
utils::Printf("%lux%lu matrix with %lu entries is loaded from %s\n", utils::Printf("%lux%lu matrix with %lu entries is loaded from %s\n",
static_cast<unsigned long>(info.num_row()), static_cast<unsigned long>(info.num_row()), // NOLINT(*)
static_cast<unsigned long>(info.num_col()), static_cast<unsigned long>(info.num_col()), // NOLINT(*)
static_cast<unsigned long>(row_data_.size()), uri); static_cast<unsigned long>(row_data_.size()), uri); // NOLINT(*)
} }
// try to load in additional file // try to load in additional file
if (!loadsplit) { if (!loadsplit) {
@ -141,7 +143,7 @@ class DMatrixSimple : public DataMatrix {
"DMatrix: weight data does not match the number of rows in features"); "DMatrix: weight data does not match the number of rows in features");
} }
std::string mname = name + ".base_margin"; std::string mname = name + ".base_margin";
if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) { if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) {
} }
} }
} }
@ -165,10 +167,11 @@ class DMatrixSimple : public DataMatrix {
* \param silent whether print information during loading * \param silent whether print information during loading
* \param fname file name, used to print message * \param fname file name, used to print message
*/ */
inline void LoadBinary(utils::IStream &fs, bool silent = false, const char *fname = NULL) { inline void LoadBinary(utils::IStream &fs, bool silent = false, const char *fname = NULL) { // NOLINT(*)
int tmagic; int tmagic;
utils::Check(fs.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format"); utils::Check(fs.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format");
utils::Check(tmagic == kMagic, "\"%s\" invalid format, magic number mismatch", fname == NULL ? "" : fname); utils::Check(tmagic == kMagic, "\"%s\" invalid format, magic number mismatch",
fname == NULL ? "" : fname);
info.LoadBinary(fs); info.LoadBinary(fs);
LoadBinary(fs, &row_ptr_, &row_data_); LoadBinary(fs, &row_ptr_, &row_data_);
@ -176,9 +179,9 @@ class DMatrixSimple : public DataMatrix {
if (!silent) { if (!silent) {
utils::Printf("%lux%lu matrix with %lu entries is loaded", utils::Printf("%lux%lu matrix with %lu entries is loaded",
static_cast<unsigned long>(info.num_row()), static_cast<unsigned long>(info.num_row()), // NOLINT(*)
static_cast<unsigned long>(info.num_col()), static_cast<unsigned long>(info.num_col()), // NOLINT(*)
static_cast<unsigned long>(row_data_.size())); static_cast<unsigned long>(row_data_.size())); // NOLINT(*)
if (fname != NULL) { if (fname != NULL) {
utils::Printf(" from %s\n", fname); utils::Printf(" from %s\n", fname);
} else { } else {
@ -205,9 +208,9 @@ class DMatrixSimple : public DataMatrix {
if (!silent) { if (!silent) {
utils::Printf("%lux%lu matrix with %lu entries is saved to %s\n", utils::Printf("%lux%lu matrix with %lu entries is saved to %s\n",
static_cast<unsigned long>(info.num_row()), static_cast<unsigned long>(info.num_row()), // NOLINT(*)
static_cast<unsigned long>(info.num_col()), static_cast<unsigned long>(info.num_col()), // NOLINT(*)
static_cast<unsigned long>(row_data_.size()), fname); static_cast<unsigned long>(row_data_.size()), fname); // NOLINT(*)
if (info.group_ptr.size() != 0) { if (info.group_ptr.size() != 0) {
utils::Printf("data contains %u groups\n", utils::Printf("data contains %u groups\n",
static_cast<unsigned>(info.group_ptr.size()-1)); static_cast<unsigned>(info.group_ptr.size()-1));
@ -256,7 +259,7 @@ class DMatrixSimple : public DataMatrix {
* \param ptr pointer data * \param ptr pointer data
* \param data data content * \param data data content
*/ */
inline static void SaveBinary(utils::IStream &fo, inline static void SaveBinary(utils::IStream &fo, // NOLINT(*)
const std::vector<size_t> &ptr, const std::vector<size_t> &ptr,
const std::vector<RowBatch::Entry> &data) { const std::vector<RowBatch::Entry> &data) {
size_t nrow = ptr.size() - 1; size_t nrow = ptr.size() - 1;
@ -272,7 +275,7 @@ class DMatrixSimple : public DataMatrix {
* \param out_ptr pointer data * \param out_ptr pointer data
* \param out_data data content * \param out_data data content
*/ */
inline static void LoadBinary(utils::IStream &fi, inline static void LoadBinary(utils::IStream &fi, // NOLINT(*)
std::vector<size_t> *out_ptr, std::vector<size_t> *out_ptr,
std::vector<RowBatch::Entry> *out_data) { std::vector<RowBatch::Entry> *out_data) {
size_t nrow; size_t nrow;
@ -314,7 +317,7 @@ class DMatrixSimple : public DataMatrix {
DMatrixSimple *parent_; DMatrixSimple *parent_;
// temporary space for batch // temporary space for batch
RowBatch batch_; RowBatch batch_;
}; };
}; };
} // namespace io } // namespace io
} // namespace xgboost } // namespace xgboost

View File

@ -1,11 +1,15 @@
#ifndef XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP_
#define XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP_
/*! /*!
* Copyright 2014 by Contributors
* \file simple_fmatrix-inl.hpp * \file simple_fmatrix-inl.hpp
* \brief the input data structure for gradient boosting * \brief the input data structure for gradient boosting
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP_
#define XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP_
#include <limits> #include <limits>
#include <algorithm>
#include <vector>
#include "../data.h" #include "../data.h"
#include "../utils/utils.h" #include "../utils/utils.h"
#include "../utils/random.h" #include "../utils/random.h"
@ -30,7 +34,7 @@ class FMatrixS : public IFMatrix {
} }
// destructor // destructor
virtual ~FMatrixS(void) { virtual ~FMatrixS(void) {
if (iter_ != NULL) delete iter_; if (iter_ != NULL) delete iter_;
} }
/*! \return whether column access is enabled */ /*! \return whether column access is enabled */
virtual bool HaveColAccess(void) const { virtual bool HaveColAccess(void) const {
@ -54,7 +58,7 @@ class FMatrixS : public IFMatrix {
size_t nmiss = buffered_rowset_.size() - col_size_[cidx]; size_t nmiss = buffered_rowset_.size() - col_size_[cidx];
return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size(); return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
} }
virtual void InitColAccess(const std::vector<bool> &enabled, virtual void InitColAccess(const std::vector<bool> &enabled,
float pkeep, size_t max_row_perbatch) { float pkeep, size_t max_row_perbatch) {
if (this->HaveColAccess()) return; if (this->HaveColAccess()) return;
this->InitColData(enabled, pkeep, max_row_perbatch); this->InitColData(enabled, pkeep, max_row_perbatch);
@ -85,7 +89,7 @@ class FMatrixS : public IFMatrix {
size_t ncol = this->NumCol(); size_t ncol = this->NumCol();
col_iter_.col_index_.resize(0); col_iter_.col_index_.resize(0);
for (size_t i = 0; i < fset.size(); ++i) { for (size_t i = 0; i < fset.size(); ++i) {
if (fset[i] < ncol) col_iter_.col_index_.push_back(fset[i]); if (fset[i] < ncol) col_iter_.col_index_.push_back(fset[i]);
} }
col_iter_.BeforeFirst(); col_iter_.BeforeFirst();
return &col_iter_; return &col_iter_;
@ -94,7 +98,7 @@ class FMatrixS : public IFMatrix {
* \brief save column access data into stream * \brief save column access data into stream
* \param fo output stream to save to * \param fo output stream to save to
*/ */
inline void SaveColAccess(utils::IStream &fo) const { inline void SaveColAccess(utils::IStream &fo) const { // NOLINT(*)
size_t n = 0; size_t n = 0;
fo.Write(&n, sizeof(n)); fo.Write(&n, sizeof(n));
} }
@ -102,10 +106,10 @@ class FMatrixS : public IFMatrix {
* \brief load column access data from stream * \brief load column access data from stream
* \param fi input stream to load from * \param fi input stream to load from
*/ */
inline void LoadColAccess(utils::IStream &fi) { inline void LoadColAccess(utils::IStream &fi) { // NOLINT(*)
// do nothing in load col access // do nothing in load col access
} }
protected: protected:
/*! /*!
* \brief initialize column data * \brief initialize column data
@ -129,7 +133,7 @@ class FMatrixS : public IFMatrix {
for (size_t i = 0; i < col_iter_.cpages_.size(); ++i) { for (size_t i = 0; i < col_iter_.cpages_.size(); ++i) {
SparsePage *pcol = col_iter_.cpages_[i]; SparsePage *pcol = col_iter_.cpages_[i];
for (size_t j = 0; j < pcol->Size(); ++j) { for (size_t j = 0; j < pcol->Size(); ++j) {
col_size_[j] += pcol->offset[j + 1] - pcol->offset[j]; col_size_[j] += pcol->offset[j + 1] - pcol->offset[j];
} }
} }
} }
@ -139,7 +143,7 @@ class FMatrixS : public IFMatrix {
* \param pcol the target column * \param pcol the target column
*/ */
inline void MakeOneBatch(const std::vector<bool> &enabled, inline void MakeOneBatch(const std::vector<bool> &enabled,
float pkeep, float pkeep,
SparsePage *pcol) { SparsePage *pcol) {
// clear rowset // clear rowset
buffered_rowset_.clear(); buffered_rowset_.clear();
@ -159,8 +163,8 @@ class FMatrixS : public IFMatrix {
while (iter_->Next()) { while (iter_->Next()) {
const RowBatch &batch = iter_->Value(); const RowBatch &batch = iter_->Value();
bmap.resize(bmap.size() + batch.size, true); bmap.resize(bmap.size() + batch.size, true);
long batch_size = static_cast<long>(batch.size); long batch_size = static_cast<long>(batch.size); // NOLINT(*)
for (long i = 0; i < batch_size; ++i) { for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i); bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (pkeep == 1.0f || random::SampleBinary(pkeep)) { if (pkeep == 1.0f || random::SampleBinary(pkeep)) {
buffered_rowset_.push_back(ridx); buffered_rowset_.push_back(ridx);
@ -169,13 +173,13 @@ class FMatrixS : public IFMatrix {
} }
} }
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (long i = 0; i < batch_size; ++i) { for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i); bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (bmap[ridx]) { if (bmap[ridx]) {
RowBatch::Inst inst = batch[i]; RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) { for (bst_uint j = 0; j < inst.length; ++j) {
if (enabled[inst[j].index]){ if (enabled[inst[j].index]) {
builder.AddBudget(inst[j].index, tid); builder.AddBudget(inst[j].index, tid);
} }
} }
@ -183,18 +187,18 @@ class FMatrixS : public IFMatrix {
} }
} }
builder.InitStorage(); builder.InitStorage();
iter_->BeforeFirst(); iter_->BeforeFirst();
while (iter_->Next()) { while (iter_->Next()) {
const RowBatch &batch = iter_->Value(); const RowBatch &batch = iter_->Value();
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (long i = 0; i < static_cast<long>(batch.size); ++i) { for (long i = 0; i < static_cast<long>(batch.size); ++i) { // NOLINT(*)
int tid = omp_get_thread_num(); int tid = omp_get_thread_num();
bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i); bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (bmap[ridx]) { if (bmap[ridx]) {
RowBatch::Inst inst = batch[i]; RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) { for (bst_uint j = 0; j < inst.length; ++j) {
if (enabled[inst[j].index]) { if (enabled[inst[j].index]) {
builder.Push(inst[j].index, builder.Push(inst[j].index,
Entry((bst_uint)(batch.base_rowid+i), Entry((bst_uint)(batch.base_rowid+i),
inst[j].fvalue), tid); inst[j].fvalue), tid);
@ -261,7 +265,7 @@ class FMatrixS : public IFMatrix {
#pragma omp parallel #pragma omp parallel
{ {
nthread = omp_get_num_threads(); nthread = omp_get_num_threads();
int max_nthread = std::max(omp_get_num_procs() / 2 - 2, 1); int max_nthread = std::max(omp_get_num_procs() / 2 - 2, 1);
if (nthread > max_nthread) { if (nthread > max_nthread) {
nthread = max_nthread; nthread = max_nthread;
} }
@ -277,7 +281,7 @@ class FMatrixS : public IFMatrix {
RowBatch::Inst inst = batch[i]; RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) { for (bst_uint j = 0; j < inst.length; ++j) {
const SparseBatch::Entry &e = inst[j]; const SparseBatch::Entry &e = inst[j];
if (enabled[e.index]) { if (enabled[e.index]) {
builder.AddBudget(e.index, tid); builder.AddBudget(e.index, tid);
} }
} }
@ -330,10 +334,10 @@ class FMatrixS : public IFMatrix {
static_cast<bst_uint>(pcol->offset[ridx + 1] - pcol->offset[ridx])); static_cast<bst_uint>(pcol->offset[ridx + 1] - pcol->offset[ridx]));
} }
batch_.col_index = BeginPtr(col_index_); batch_.col_index = BeginPtr(col_index_);
batch_.col_data = BeginPtr(col_data_); batch_.col_data = BeginPtr(col_data_);
return true; return true;
} }
virtual const ColBatch &Value(void) const { virtual const ColBatch &Value(void) const {
return batch_; return batch_;
} }
inline void Clear(void) { inline void Clear(void) {
@ -347,7 +351,7 @@ class FMatrixS : public IFMatrix {
// column content // column content
std::vector<ColBatch::Inst> col_data_; std::vector<ColBatch::Inst> col_data_;
// column sparse pages // column sparse pages
std::vector<SparsePage*> cpages_; std::vector<SparsePage*> cpages_;
// data pointer // data pointer
size_t data_ptr_; size_t data_ptr_;
 // temporary space for batch // temporary space for batch
@ -357,7 +361,7 @@ class FMatrixS : public IFMatrix {
// column iterator // column iterator
ColBatchIter col_iter_; ColBatchIter col_iter_;
// shared meta info with DMatrix // shared meta info with DMatrix
const learner::MetaInfo &info_; const learner::MetaInfo &info_;
// row iterator // row iterator
utils::IIterator<RowBatch> *iter_; utils::IIterator<RowBatch> *iter_;
/*! \brief list of row index that are buffered */ /*! \brief list of row index that are buffered */
@ -367,4 +371,4 @@ class FMatrixS : public IFMatrix {
}; };
} // namespace io } // namespace io
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_IO_SLICE_FMATRIX_INL_HPP #endif // XGBOOST_IO_SLICE_FMATRIX_INL_HPP_

View File

@ -1,18 +1,22 @@
#ifndef XGBOOST_IO_SPARSE_BATCH_PAGE_H_
#define XGBOOST_IO_SPARSE_BATCH_PAGE_H_
/*! /*!
* Copyright (c) 2014 by Contributors
* \file sparse_batch_page.h * \file sparse_batch_page.h
* content holder of sparse batch that can be saved to disk * content holder of sparse batch that can be saved to disk
* the representation can be effectively * the representation can be effectively
 *   used in external memory computation * used in external memory computation
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#ifndef XGBOOST_IO_SPARSE_BATCH_PAGE_H_
#define XGBOOST_IO_SPARSE_BATCH_PAGE_H_
#include <vector>
#include <algorithm>
#include "../data.h" #include "../data.h"
namespace xgboost { namespace xgboost {
namespace io { namespace io {
/*! /*!
* \brief storage unit of sparse batch * \brief storage unit of sparse batch
*/ */
class SparsePage { class SparsePage {
public: public:
@ -96,7 +100,7 @@ class SparsePage {
} }
/*! /*!
* \brief save the data to fo, when a page was written * \brief save the data to fo, when a page was written
* to disk it must contain all the elements in the * to disk it must contain all the elements in the
* \param fo output stream * \param fo output stream
*/ */
inline void Save(utils::IStream *fo) const { inline void Save(utils::IStream *fo) const {
@ -124,7 +128,7 @@ class SparsePage {
*/ */
inline bool PushLoad(utils::IStream *fi) { inline bool PushLoad(utils::IStream *fi) {
if (!fi->Read(&disk_offset_)) return false; if (!fi->Read(&disk_offset_)) return false;
data.resize(offset.back() + disk_offset_.back()); data.resize(offset.back() + disk_offset_.back());
if (disk_offset_.back() != 0) { if (disk_offset_.back() != 0) {
utils::Check(fi->Read(BeginPtr(data) + offset.back(), utils::Check(fi->Read(BeginPtr(data) + offset.back(),
disk_offset_.back() * sizeof(SparseBatch::Entry)) != 0, disk_offset_.back() * sizeof(SparseBatch::Entry)) != 0,
@ -138,7 +142,7 @@ class SparsePage {
} }
return true; return true;
} }
/*! /*!
* \brief Push row batch into the page * \brief Push row batch into the page
* \param batch the row batch * \param batch the row batch
*/ */
@ -154,7 +158,7 @@ class SparsePage {
offset[i + begin] = top + batch.ind_ptr[i + 1] - batch.ind_ptr[0]; offset[i + begin] = top + batch.ind_ptr[i + 1] - batch.ind_ptr[0];
} }
} }
/*! /*!
* \brief Push a sparse page * \brief Push a sparse page
* \param batch the row page * \param batch the row page
*/ */
@ -170,7 +174,7 @@ class SparsePage {
offset[i + begin] = top + batch.offset[i + 1]; offset[i + begin] = top + batch.offset[i + 1];
} }
} }
/*! /*!
* \brief Push one instance into page * \brief Push one instance into page
* \param row an instance row * \param row an instance row
*/ */
@ -202,7 +206,7 @@ class SparsePage {
}; };
/*! /*!
* \brief factory class for SparsePage, * \brief factory class for SparsePage,
* used in threadbuffer template * used in threadbuffer template
*/ */
class SparsePageFactory { class SparsePageFactory {
public: public:
@ -217,7 +221,7 @@ class SparsePageFactory {
return action_index_set_; return action_index_set_;
} }
// set index set, will be used after next before first // set index set, will be used after next before first
inline void SetIndexSet(const std::vector<bst_uint> &index_set, inline void SetIndexSet(const std::vector<bst_uint> &index_set,
bool load_all) { bool load_all) {
set_load_all_ = load_all; set_load_all_ = load_all;
if (!set_load_all_) { if (!set_load_all_) {
@ -229,7 +233,7 @@ class SparsePageFactory {
return true; return true;
} }
inline void SetParam(const char *name, const char *val) {} inline void SetParam(const char *name, const char *val) {}
inline bool LoadNext(SparsePage *val) { inline bool LoadNext(SparsePage *val) {
if (!action_load_all_) { if (!action_load_all_) {
if (action_index_set_.size() == 0) { if (action_index_set_.size() == 0) {
return false; return false;

View File

@ -1,18 +1,20 @@
// Copyright 2014 by Contributors
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_DEPRECATE #define _CRT_SECURE_NO_DEPRECATE
#define NOMINMAX #define NOMINMAX
#include <ctime> #include <ctime>
#include <string> #include <string>
#include <cstring> #include <cstring>
#include <vector>
#include "./sync/sync.h" #include "./sync/sync.h"
#include "io/io.h" #include "./io/io.h"
#include "utils/utils.h" #include "./utils/utils.h"
#include "utils/config.h" #include "./utils/config.h"
#include "learner/learner-inl.hpp" #include "./learner/learner-inl.hpp"
namespace xgboost { namespace xgboost {
/*! /*!
* \brief wrapping the training process * \brief wrapping the training process
*/ */
class BoostLearnTask { class BoostLearnTask {
public: public:
@ -20,7 +22,7 @@ class BoostLearnTask {
if (argc < 2) { if (argc < 2) {
printf("Usage: <config>\n"); printf("Usage: <config>\n");
return 0; return 0;
} }
utils::ConfigIterator itr(argv[1]); utils::ConfigIterator itr(argv[1]);
while (itr.Next()) { while (itr.Next()) {
this->SetParam(itr.name(), itr.val()); this->SetParam(itr.name(), itr.val());
@ -44,10 +46,10 @@ class BoostLearnTask {
} }
if (rabit::IsDistributed() && data_split == "NONE") { if (rabit::IsDistributed() && data_split == "NONE") {
this->SetParam("dsplit", "row"); this->SetParam("dsplit", "row");
} }
if (rabit::GetRank() != 0) { if (rabit::GetRank() != 0) {
this->SetParam("silent", "2"); this->SetParam("silent", "2");
} }
this->InitData(); this->InitData();
if (task == "train") { if (task == "train") {
@ -90,12 +92,14 @@ class BoostLearnTask {
if (!strcmp("save_pbuffer", name)) save_with_pbuffer = atoi(val); if (!strcmp("save_pbuffer", name)) save_with_pbuffer = atoi(val);
if (!strncmp("eval[", name, 5)) { if (!strncmp("eval[", name, 5)) {
char evname[256]; char evname[256];
utils::Assert(sscanf(name, "eval[%[^]]", evname) == 1, "must specify evaluation name for display"); utils::Assert(sscanf(name, "eval[%[^]]", evname) == 1,
"must specify evaluation name for display");
eval_data_names.push_back(std::string(evname)); eval_data_names.push_back(std::string(evname));
eval_data_paths.push_back(std::string(val)); eval_data_paths.push_back(std::string(val));
} }
learner.SetParam(name, val); learner.SetParam(name, val);
} }
public: public:
BoostLearnTask(void) { BoostLearnTask(void) {
// default parameters // default parameters
@ -119,12 +123,13 @@ class BoostLearnTask {
save_with_pbuffer = 0; save_with_pbuffer = 0;
data = NULL; data = NULL;
} }
~BoostLearnTask(void){ ~BoostLearnTask(void) {
for (size_t i = 0; i < deval.size(); i++){ for (size_t i = 0; i < deval.size(); i++) {
delete deval[i]; delete deval[i];
} }
if (data != NULL) delete data; if (data != NULL) delete data;
} }
private: private:
inline void InitData(void) { inline void InitData(void) {
if (strchr(train_path.c_str(), '%') != NULL) { if (strchr(train_path.c_str(), '%') != NULL) {
@ -151,14 +156,14 @@ class BoostLearnTask {
loadsplit)); loadsplit));
devalall.push_back(deval.back()); devalall.push_back(deval.back());
} }
std::vector<io::DataMatrix *> dcache(1, data); std::vector<io::DataMatrix *> dcache(1, data);
for (size_t i = 0; i < deval.size(); ++ i) { for (size_t i = 0; i < deval.size(); ++i) {
dcache.push_back(deval[i]); dcache.push_back(deval[i]);
} }
// set cache data to be all training and evaluation data // set cache data to be all training and evaluation data
learner.SetCacheData(dcache); learner.SetCacheData(dcache);
// add training set to evaluation set if needed // add training set to evaluation set if needed
if (eval_train != 0) { if (eval_train != 0) {
devalall.push_back(data); devalall.push_back(data);
@ -178,13 +183,13 @@ class BoostLearnTask {
int version = rabit::LoadCheckPoint(&learner); int version = rabit::LoadCheckPoint(&learner);
if (version == 0) this->InitLearner(); if (version == 0) this->InitLearner();
const time_t start = time(NULL); const time_t start = time(NULL);
unsigned long elapsed = 0; unsigned long elapsed = 0; // NOLINT(*)
learner.CheckInit(data); learner.CheckInit(data);
bool allow_lazy = learner.AllowLazyCheckPoint(); bool allow_lazy = learner.AllowLazyCheckPoint();
for (int i = version / 2; i < num_round; ++i) { for (int i = version / 2; i < num_round; ++i) {
elapsed = (unsigned long)(time(NULL) - start); elapsed = (unsigned long)(time(NULL) - start); // NOLINT(*)
if (version % 2 == 0) { if (version % 2 == 0) {
if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed); if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed);
learner.UpdateOneIter(i, *data); learner.UpdateOneIter(i, *data);
if (allow_lazy) { if (allow_lazy) {
@ -196,7 +201,7 @@ class BoostLearnTask {
} }
utils::Assert(version == rabit::VersionNumber(), "consistent check"); utils::Assert(version == rabit::VersionNumber(), "consistent check");
std::string res = learner.EvalOneIter(i, devalall, eval_data_names); std::string res = learner.EvalOneIter(i, devalall, eval_data_names);
if (rabit::IsDistributed()){ if (rabit::IsDistributed()) {
if (rabit::GetRank() == 0) { if (rabit::GetRank() == 0) {
rabit::TrackerPrintf("%s\n", res.c_str()); rabit::TrackerPrintf("%s\n", res.c_str());
} }
@ -215,29 +220,29 @@ class BoostLearnTask {
} }
version += 1; version += 1;
utils::Assert(version == rabit::VersionNumber(), "consistent check"); utils::Assert(version == rabit::VersionNumber(), "consistent check");
elapsed = (unsigned long)(time(NULL) - start); elapsed = (unsigned long)(time(NULL) - start); // NOLINT(*)
} }
// always save final round // always save final round
if ((save_period == 0 || num_round % save_period != 0) && model_out != "NONE") { if ((save_period == 0 || num_round % save_period != 0) && model_out != "NONE") {
if (model_out == "NULL"){ if (model_out == "NULL") {
this->SaveModel(num_round - 1); this->SaveModel(num_round - 1);
} else { } else {
this->SaveModel(model_out.c_str()); this->SaveModel(model_out.c_str());
} }
} }
if (!silent){ if (!silent) {
printf("\nupdating end, %lu sec in all\n", elapsed); printf("\nupdating end, %lu sec in all\n", elapsed);
} }
} }
inline void TaskEval(void) { inline void TaskEval(void) {
learner.EvalOneIter(0, devalall, eval_data_names); learner.EvalOneIter(0, devalall, eval_data_names);
} }
inline void TaskDump(void){ inline void TaskDump(void) {
FILE *fo = utils::FopenCheck(name_dump.c_str(), "w"); FILE *fo = utils::FopenCheck(name_dump.c_str(), "w");
std::vector<std::string> dump = learner.DumpModel(fmap, dump_model_stats != 0); std::vector<std::string> dump = learner.DumpModel(fmap, dump_model_stats != 0);
for (size_t i = 0; i < dump.size(); ++ i) { for (size_t i = 0; i < dump.size(); ++i) {
fprintf(fo,"booster[%lu]:\n", i); fprintf(fo, "booster[%lu]:\n", i);
fprintf(fo,"%s", dump[i].c_str()); fprintf(fo, "%s", dump[i].c_str());
} }
fclose(fo); fclose(fo);
} }
@ -247,14 +252,15 @@ class BoostLearnTask {
} }
inline void SaveModel(int i) const { inline void SaveModel(int i) const {
char fname[256]; char fname[256];
sprintf(fname, "%s/%04d.model", model_dir_path.c_str(), i + 1); utils::SPrintf(fname, sizeof(fname),
"%s/%04d.model", model_dir_path.c_str(), i + 1);
this->SaveModel(fname); this->SaveModel(fname);
} }
inline void TaskPred(void) { inline void TaskPred(void) {
std::vector<float> preds; std::vector<float> preds;
if (!silent) printf("start prediction...\n"); if (!silent) printf("start prediction...\n");
learner.Predict(*data, pred_margin != 0, &preds, ntree_limit); learner.Predict(*data, pred_margin != 0, &preds, ntree_limit);
if (!silent) printf("writing prediction to %s\n", name_pred.c_str()); if (!silent) printf("writing prediction to %s\n", name_pred.c_str());
FILE *fo; FILE *fo;
if (name_pred != "stdout") { if (name_pred != "stdout") {
fo = utils::FopenCheck(name_pred.c_str(), "w"); fo = utils::FopenCheck(name_pred.c_str(), "w");
@ -266,6 +272,7 @@ class BoostLearnTask {
} }
if (fo != stdout) fclose(fo); if (fo != stdout) fclose(fo);
} }
private: private:
/*! \brief whether silent */ /*! \brief whether silent */
int silent; int silent;
@ -273,7 +280,7 @@ class BoostLearnTask {
int load_part; int load_part;
/*! \brief whether use auto binary buffer */ /*! \brief whether use auto binary buffer */
int use_buffer; int use_buffer;
/*! \brief whether evaluate training statistics */ /*! \brief whether evaluate training statistics */
int eval_train; int eval_train;
/*! \brief number of boosting iterations */ /*! \brief number of boosting iterations */
int num_round; int num_round;
@ -309,6 +316,7 @@ class BoostLearnTask {
std::vector<std::string> eval_data_paths; std::vector<std::string> eval_data_paths;
/*! \brief the names of the evaluation data used in output log */ /*! \brief the names of the evaluation data used in output log */
std::vector<std::string> eval_data_names; std::vector<std::string> eval_data_names;
private: private:
io::DataMatrix* data; io::DataMatrix* data;
std::vector<io::DataMatrix*> deval; std::vector<io::DataMatrix*> deval;
@ -316,9 +324,9 @@ class BoostLearnTask {
utils::FeatMap fmap; utils::FeatMap fmap;
learner::BoostLearner learner; learner::BoostLearner learner;
}; };
} } // namespace xgboost
int main(int argc, char *argv[]){ int main(int argc, char *argv[]) {
xgboost::BoostLearnTask tsk; xgboost::BoostLearnTask tsk;
tsk.SetParam("seed", "0"); tsk.SetParam("seed", "0");
int ret = tsk.Run(argc, argv); int ret = tsk.Run(argc, argv);