From 551b3b70f1710b30739f69cdf3c20bf87877d70c Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 29 Aug 2014 18:31:24 -0700 Subject: [PATCH 01/19] check unity back --- src/io/io.cpp | 2 + src/io/page_dmatrix-inl.hpp | 204 ++++++++++++++++++++++++ src/utils/io.h | 26 +++- src/utils/thread.h | 146 ++++++++++++++++++ src/utils/thread_buffer.h | 200 ++++++++++++++++++++++++ wrapper/xgboost.py | 298 ++++++++++++++++++++++++++++++++---- 6 files changed, 844 insertions(+), 32 deletions(-) create mode 100644 src/io/page_dmatrix-inl.hpp create mode 100644 src/utils/thread.h create mode 100644 src/utils/thread_buffer.h diff --git a/src/io/io.cpp b/src/io/io.cpp index d251d7a96..f56cff679 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -5,6 +5,8 @@ #include "../utils/io.h" #include "../utils/utils.h" #include "simple_dmatrix-inl.hpp" +#include "page_dmatrix-inl.hpp" + // implements data loads using dmatrix simple for now namespace xgboost { diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp new file mode 100644 index 000000000..82a373352 --- /dev/null +++ b/src/io/page_dmatrix-inl.hpp @@ -0,0 +1,204 @@ +#ifndef XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_ +#define XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_ +/*! + * \file page_row_iter-inl.hpp + * row iterator based on sparse page + * \author Tianqi Chen + */ +#include "../data.h" +#include "../utils/iterator.h" +#include "../utils/thread_buffer.h" +namespace xgboost { +namespace io { +/*! \brief page structure that can be used to store a rowbatch */ +struct RowBatchPage { + public: + RowBatchPage(void) { + data_ = new int[kPageSize]; + utils::Assert(data_ != NULL, "fail to allocate row batch page"); + this->Clear(); + } + ~RowBatchPage(void) { + if (data_ != NULL) delete [] data_; + } + /*! + * \brief Push one row into page + * \param row an instance row + * \return false or true to push into + */ + inline bool PushRow(const RowBatch::Inst &row) { + const size_t dsize = row.length * sizeof(RowBatch::Entry); + if (FreeBytes() < dsize+ sizeof(int)) return false; + row_ptr(Size() + 1) = row_ptr(Size()) + row.length; + memcpy(data_ptr(Size()) , row.data, dsize); + ++ data_[0]; + return true; + } + /*! + * \brief get a row batch representation from the page + * \param p_rptr a temporal space that can be used to provide + * ind_ptr storage for RowBatch + * \return a new RowBatch object + */ + inline RowBatch GetRowBatch(std::vector *p_rptr, size_t base_rowid) { + RowBatch batch; + batch.base_rowid = base_rowid; + batch.data_ptr = this->data_ptr(0); + batch.size = static_cast(this->Size()); + std::vector &rptr = *p_rptr; + rptr.resize(this->Size()+1); + for (size_t i = 0; i < rptr.size(); ++i) { + rptr[i] = static_cast(this->row_ptr(i)); + } + batch.ind_ptr = &rptr[0]; + return batch; + } + /*! + * \brief clear the page, cleanup the content + */ + inline void Clear(void) { + memset(&data_[0], 0, sizeof(int) * kPageSize); + } + /*! + * \brief load one page form instream + * \return true if loading is successful + */ + inline bool Load(utils::IStream &fi) { + return fi.Read(&data_[0], sizeof(int) * kPageSize) != 0; + } + /*! \brief save one page into outstream */ + inline void Save(utils::IStream &fo) { + fo.Write(&data_[0], sizeof(int) * kPageSize); + } + /*! \return number of elements */ + inline int Size(void) const { + return data_[0]; + } + /*! \brief page size 64 MB */ + static const size_t kPageSize = 64 << 18; + + private: + /*! \return number of elements */ + inline size_t FreeBytes(void) { + return (kPageSize - (Size() + 2)) * sizeof(int) + - row_ptr(Size()) * sizeof(RowBatch::Entry) ; + } + /*! \brief equivalent row pointer at i */ + inline int& row_ptr(int i) { + return data_[kPageSize - i - 1]; + } + inline RowBatch::Entry* data_ptr(int i) { + return (RowBatch::Entry*)(&data_[1]) + i; + } + // content of data + int *data_; +}; +/*! \brief thread buffer iterator */ +class ThreadRowPageIterator: public utils::IIterator { + public: + ThreadRowPageIterator(void) { + itr.SetParam("buffer_size", "4"); + page_ = NULL; + base_rowid_ = 0; + isend_ = false; + } + virtual ~ThreadRowPageIterator(void) { + } + virtual void Init(void) { + } + virtual void BeforeFirst(void) { + itr.BeforeFirst(); + isend_ = false; + base_rowid_ = 0; + utils::Assert(this->LoadNextPage(), "ThreadRowPageIterator"); + } + virtual bool Next(void) { + if(!this->LoadNextPage()) return false; + out_ = page_->GetRowBatch(&tmp_ptr_, base_rowid_); + base_rowid_ += out_.size; + return true; + } + virtual const RowBatch &Value(void) const{ + return out_; + } + /*! \brief load and initialize the iterator with fi */ + inline void Load(const utils::FileStream &fi) { + itr.get_factory().SetFile(fi); + itr.Init(); + this->BeforeFirst(); + } + /*! + * \brief save a row iterator to output stream, in row iterator format + */ + inline static void Save(utils::IIterator *iter, + utils::IStream &fo) { + RowBatchPage page; + iter->BeforeFirst(); + while (iter->Next()) { + const RowBatch &batch = iter->Value(); + for (size_t i = 0; i < batch.size; ++i) { + if (!page.PushRow(batch[i])) { + page.Save(fo); + page.Clear(); + utils::Check(page.PushRow(batch[i]), "row is too big"); + } + } + } + if (page.Size() != 0) page.Save(fo); + } + private: + // load in next page + inline bool LoadNextPage(void) { + ptop_ = 0; + bool ret = itr.Next(page_); + isend_ = !ret; + return ret; + } + // base row id + size_t base_rowid_; + // temporal ptr + std::vector tmp_ptr_; + // output data + RowBatch out_; + // whether we reach end of file + bool isend_; + // page pointer type + typedef RowBatchPage* PagePtr; + // loader factory for page + struct Factory { + public: + size_t file_begin_; + utils::FileStream fi; + Factory(void) {} + inline void SetFile(const utils::FileStream &fi) { + this->fi = fi; + file_begin_ = this->fi.Tell(); + } + inline bool Init(void) { + return true; + } + inline void SetParam(const char *name, const char *val) {} + inline bool LoadNext(PagePtr &val) { + return val->Load(fi); + } + inline PagePtr Create(void) { + PagePtr a = new RowBatchPage(); + return a; + } + inline void FreeSpace(PagePtr &a) { + delete a; + } + inline void Destroy(void) {} + inline void BeforeFirst(void) { + fi.Seek(file_begin_); + } + }; + + protected: + PagePtr page_; + int ptop_; + utils::ThreadBuffer itr; +}; +} // namespace io +} // namespace xgboost +#endif // XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_ diff --git a/src/utils/io.h b/src/utils/io.h index 4a80e9a58..23fa0d468 100644 --- a/src/utils/io.h +++ b/src/utils/io.h @@ -88,11 +88,19 @@ class IStream { } }; -/*! \brief implementation of file i/o stream */ -class FileStream : public IStream { - private: - FILE *fp; +/*! \brief interface of i/o stream that support seek */ +class ISeekStream: public IStream { public: + /*! \brief seek to certain position of the file */ + virtual void Seek(size_t pos) = 0; + /*! \brief tell the position of the stream */ + virtual size_t Tell(void) = 0; +}; + +/*! \brief implementation of file i/o stream */ +class FileStream : public ISeekStream { + public: + explicit FileStream(void) {} explicit FileStream(FILE *fp) { this->fp = fp; } @@ -102,12 +110,18 @@ class FileStream : public IStream { virtual void Write(const void *ptr, size_t size) { fwrite(ptr, size, 1, fp); } - inline void Seek(size_t pos) { - fseek(fp, 0, SEEK_SET); + virtual void Seek(size_t pos) { + fseek(fp, pos, SEEK_SET); + } + virtual size_t Tell(void) { + return static_cast(ftell(fp)); } inline void Close(void) { fclose(fp); } + + private: + FILE *fp; }; } // namespace utils diff --git a/src/utils/thread.h b/src/utils/thread.h new file mode 100644 index 000000000..830b21cbe --- /dev/null +++ b/src/utils/thread.h @@ -0,0 +1,146 @@ +#ifndef XGBOOST_UTILS_THREAD_H +#define XGBOOST_UTILS_THREAD_H +/*! + * \file thread.h + * \brief this header include the minimum necessary resource for multi-threading + * \author Tianqi Chen + * Acknowledgement: this file is adapted from SVDFeature project, by same author. + * The MAC support part of this code is provided by Artemy Kolchinsky + */ +#ifdef _MSC_VER +#include "utils.h" +#include +#include +namespace xgboost { +namespace utils { +/*! \brief simple semaphore used for synchronization */ +class Semaphore { + public : + inline void Init(int init_val) { + sem = CreateSemaphore(NULL, init_val, 10, NULL); + utils::Assert(sem != NULL, "create Semaphore error"); + } + inline void Destroy(void) { + CloseHandle(sem); + } + inline void Wait(void) { + utils::Assert(WaitForSingleObject(sem, INFINITE) == WAIT_OBJECT_0, "WaitForSingleObject error"); + } + inline void Post(void) { + utils::Assert(ReleaseSemaphore(sem, 1, NULL) != 0, "ReleaseSemaphore error"); + } + private: + HANDLE sem; +}; +/*! \brief simple thread that wraps windows thread */ +class Thread { + private: + HANDLE thread_handle; + unsigned thread_id; + public: + inline void Start(unsigned int __stdcall entry(void*), void *param) { + thread_handle = (HANDLE)_beginthreadex(NULL, 0, entry, param, 0, &thread_id); + } + inline int Join(void) { + WaitForSingleObject(thread_handle, INFINITE); + return 0; + } +}; +/*! \brief exit function called from thread */ +inline void ThreadExit(void *status) { + _endthreadex(0); +} +#define XGBOOST_THREAD_PREFIX unsigned int __stdcall +} // namespace utils +} // namespace xgboost +#else +// thread interface using g++ +#include +#include +namespace xgboost { +namespace utils { +/*!\brief semaphore class */ +class Semaphore { + #ifdef __APPLE__ + private: + sem_t* semPtr; + char sema_name[20]; + private: + inline void GenRandomString(char *s, const int len) { + static const char alphanum[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" ; + for (int i = 0; i < len; ++i) { + s[i] = alphanum[rand() % (sizeof(alphanum) - 1)]; + } + s[len] = 0; + } + public: + inline void Init(int init_val) { + sema_name[0]='/'; + sema_name[1]='s'; + sema_name[2]='e'; + sema_name[3]='/'; + GenRandomString(&sema_name[4], 16); + if((semPtr = sem_open(sema_name, O_CREAT, 0644, init_val)) == SEM_FAILED) { + perror("sem_open"); + exit(1); + } + utils::Assert(semPtr != NULL, "create Semaphore error"); + } + inline void Destroy(void) { + if (sem_close(semPtr) == -1) { + perror("sem_close"); + exit(EXIT_FAILURE); + } + if (sem_unlink(sema_name) == -1) { + perror("sem_unlink"); + exit(EXIT_FAILURE); + } + } + inline void Wait(void) { + sem_wait(semPtr); + } + inline void Post(void) { + sem_post(semPtr); + } + #else + private: + sem_t sem; + public: + inline void Init(int init_val) { + sem_init(&sem, 0, init_val); + } + inline void Destroy(void) { + sem_destroy(&sem); + } + inline void Wait(void) { + sem_wait(&sem); + } + inline void Post(void) { + sem_post(&sem); + } + #endif +}; +/*!\brief simple thread class */ +class Thread { + private: + pthread_t thread; + public : + inline void Start(void * entry(void*), void *param) { + pthread_attr_t attr; + pthread_attr_init(&attr); + pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE); + pthread_create(&thread, &attr, entry, param); + } + inline int Join(void) { + void *status; + return pthread_join(thread, &status); + } +}; +inline void ThreadExit(void *status) { + pthread_exit(status); +} +} // namespace utils +} // namespace xgboost +#define XGBOOST_THREAD_PREFIX void * +#endif +#endif diff --git a/src/utils/thread_buffer.h b/src/utils/thread_buffer.h new file mode 100644 index 000000000..fa488a220 --- /dev/null +++ b/src/utils/thread_buffer.h @@ -0,0 +1,200 @@ +#ifndef XGBOOST_UTILS_THREAD_BUFFER_H +#define XGBOOST_UTILS_THREAD_BUFFER_H +/*! + * \file thread_buffer.h + * \brief multi-thread buffer, iterator, can be used to create parallel pipeline + * \author Tianqi Chen + */ +#include +#include +#include +#include "./utils.h" +#include "./thread.h" +namespace xgboost { +namespace utils { +/*! + * \brief buffered loading iterator that uses multithread + * this template method will assume the following paramters + * \tparam Elem elememt type to be buffered + * \tparam ElemFactory factory type to implement in order to use thread buffer + */ +template +class ThreadBuffer { + public: + /*!\brief constructor */ + ThreadBuffer(void) { + this->init_end = false; + this->buf_size = 30; + } + ~ThreadBuffer(void) { + if(init_end) this->Destroy(); + } + /*!\brief set parameter, will also pass the parameter to factory */ + inline void SetParam(const char *name, const char *val) { + if (!strcmp( name, "buffer_size")) buf_size = atoi(val); + factory.SetParam(name, val); + } + /*! + * \brief initalize the buffered iterator + * \param param a initialize parameter that will pass to factory, ignore it if not necessary + * \return false if the initlization can't be done, e.g. buffer file hasn't been created + */ + inline bool Init(void) { + if (!factory.Init()) return false; + for (int i = 0; i < buf_size; ++i) { + bufA.push_back(factory.Create()); + bufB.push_back(factory.Create()); + } + this->init_end = true; + this->StartLoader(); + return true; + } + /*!\brief place the iterator before first value */ + inline void BeforeFirst(void) { + // wait till last loader end + loading_end.Wait(); + // critcal zone + current_buf = 1; + factory.BeforeFirst(); + // reset terminate limit + endA = endB = buf_size; + // wake up loader for first part + loading_need.Post(); + // wait til first part is loaded + loading_end.Wait(); + // set current buf to right value + current_buf = 0; + // wake loader for next part + data_loaded = false; + loading_need.Post(); + // set buffer value + buf_index = 0; + } + /*! \brief destroy the buffer iterator, will deallocate the buffer */ + inline void Destroy(void) { + // wait until the signal is consumed + this->destroy_signal = true; + loading_need.Post(); + loader_thread.Join(); + loading_need.Destroy(); + loading_end.Destroy(); + for (size_t i = 0; i < bufA.size(); ++i) { + factory.FreeSpace(bufA[i]); + } + for (size_t i = 0; i < bufB.size(); ++i) { + factory.FreeSpace(bufB[i]); + } + bufA.clear(); bufB.clear(); + factory.Destroy(); + this->init_end = false; + } + /*! + * \brief get the next element needed in buffer + * \param elem element to store into + * \return whether reaches end of data + */ + inline bool Next(Elem &elem) { + // end of buffer try to switch + if (buf_index == buf_size) { + this->SwitchBuffer(); + buf_index = 0; + } + if (buf_index >= (current_buf ? endA : endB)) { + return false; + } + std::vector &buf = current_buf ? bufA : bufB; + elem = buf[buf_index]; + ++buf_index; + return true; + } + /*! + * \brief get the factory object + */ + inline ElemFactory &get_factory(void) { + return factory; + } + // size of buffer + int buf_size; + private: + // factory object used to load configures + ElemFactory factory; + // index in current buffer + int buf_index; + // indicate which one is current buffer + int current_buf; + // max limit of visit, also marks termination + int endA, endB; + // double buffer, one is accessed by loader + // the other is accessed by consumer + // buffer of the data + std::vector bufA, bufB; + // initialization end + bool init_end; + // singal whether the data is loaded + bool data_loaded; + // signal to kill the thread + bool destroy_signal; + // thread object + Thread loader_thread; + // signal of the buffer + Semaphore loading_end, loading_need; + /*! + * \brief slave thread + * this implementation is like producer-consumer style + */ + inline void RunLoader(void) { + while(!destroy_signal) { + // sleep until loading is needed + loading_need.Wait(); + std::vector &buf = current_buf ? bufB : bufA; + int i; + for (i = 0; i < buf_size ; ++i) { + if (!factory.LoadNext(buf[i])) { + int &end = current_buf ? endB : endA; + end = i; // marks the termination + break; + } + } + // signal that loading is done + data_loaded = true; + loading_end.Post(); + } + } + /*!\brief entry point of loader thread */ + inline static XGBOOST_THREAD_PREFIX LoaderEntry(void *pthread) { + static_cast< ThreadBuffer* >(pthread)->RunLoader(); + ThreadExit(NULL); + return NULL; + } + /*!\brief start loader thread */ + inline void StartLoader(void) { + destroy_signal = false; + // set param + current_buf = 1; + loading_need.Init(1); + loading_end .Init(0); + // reset terminate limit + endA = endB = buf_size; + loader_thread.Start(LoaderEntry, this); + // wait until first part of data is loaded + loading_end.Wait(); + // set current buf to right value + current_buf = 0; + // wake loader for next part + data_loaded = false; + loading_need.Post(); + buf_index = 0; + } + /*!\brief switch double buffer */ + inline void SwitchBuffer(void) { + loading_end.Wait(); + // loader shall be sleep now, critcal zone! + current_buf = !current_buf; + // wake up loader + data_loaded = false; + loading_need.Post(); + } +}; +} // namespace utils +} // namespace xgboost +#endif diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index e2cbdba2e..adf59c829 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -3,10 +3,11 @@ import ctypes import os # optinally have scipy sparse, though not necessary -import numpy +import numpy as np import sys import numpy.ctypeslib import scipy.sparse as scp +import random # set this line correctly if os.name == 'nt': @@ -32,16 +33,28 @@ xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p) def ctypes2numpy(cptr, length, dtype): - # convert a ctypes pointer array to numpy + """convert a ctypes pointer array to numpy array """ assert isinstance(cptr, ctypes.POINTER(ctypes.c_float)) res = numpy.zeros(length, dtype=dtype) assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]) return res -# data matrix used in xgboost class DMatrix: + """data matrix used in xgboost""" # constructor def __init__(self, data, label=None, missing=0.0, weight = None): + """ constructor of DMatrix + + Args: + data: string/numpy array/scipy.sparse + data source, string type is the path of svmlight format txt file or xgb buffer + label: list or numpy 1d array, optional + label of training data + missing: float + value in data which need to be present as missing value + weight: list or numpy 1d array, optional + weight for each instances + """ # force into void_p, mac need to pass things in as void_p if data == None: self.handle = None @@ -63,22 +76,25 @@ class DMatrix: self.set_label(label) if weight !=None: self.set_weight(weight) - # convert data from csr matrix + def __init_from_csr(self, csr): + """convert data from csr matrix""" assert len(csr.indices) == len(csr.data) self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR( (ctypes.c_ulong * len(csr.indptr))(*csr.indptr), (ctypes.c_uint * len(csr.indices))(*csr.indices), (ctypes.c_float * len(csr.data))(*csr.data), len(csr.indptr), len(csr.data))) - # convert data from numpy matrix + def __init_from_npy2d(self,mat,missing): + """convert data from numpy matrix""" data = numpy.array(mat.reshape(mat.size), dtype='float32') self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat( data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), mat.shape[0], mat.shape[1], ctypes.c_float(missing))) - # destructor + def __del__(self): + """destructor""" xglib.XGDMatrixFree(self.handle) def get_float_info(self, field): length = ctypes.c_ulong() @@ -96,16 +112,39 @@ class DMatrix: def set_uint_info(self, field, data): xglib.XGDMatrixSetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), (ctypes.c_uint*len(data))(*data), len(data)) - # load data from file + def save_binary(self, fname, silent=True): + """save DMatrix to XGBoost buffer + Args: + fname: string + name of buffer file + slient: bool, option + whether print info + Returns: + None + """ xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent)) - # set label of dmatrix + def set_label(self, label): + """set label of dmatrix + Args: + label: list + label for DMatrix + Returns: + None + """ self.set_float_info('label', label) - # set weight of each instances + def set_weight(self, weight): + """set weight of each instances + Args: + weight: float + weight for positive instance + Returns: + None + """ self.set_float_info('weight', weight) - # set initialized margin prediction + def set_base_margin(self, margin): """ set base margin of booster to start from @@ -116,31 +155,143 @@ class DMatrix: see also example/demo.py """ self.set_float_info('base_margin', margin) - # set group size of dmatrix, used for rank + def set_group(self, group): + """set group size of dmatrix, used for rank + Args: + group: + + Returns: + None + """ xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group)) - # get label from dmatrix + def get_label(self): + """get label from dmatrix + Args: + None + Returns: + list, label of data + """ return self.get_float_info('label') - # get weight from dmatrix + def get_weight(self): + """get weight from dmatrix + Args: + None + Returns: + float, weight + """ return self.get_float_info('weight') - # get base_margin from dmatrix def get_base_margin(self): + """get base_margin from dmatrix + Args: + None + Returns: + float, base margin + """ return self.get_float_info('base_margin') def num_row(self): + """get number of rows + Args: + None + Returns: + int, num rows + """ return xglib.XGDMatrixNumRow(self.handle) - # slice the DMatrix to return a new DMatrix that only contains rindex def slice(self, rindex): + """slice the DMatrix to return a new DMatrix that only contains rindex + Args: + rindex: list + list of index to be chosen + Returns: + res: DMatrix + new DMatrix with chosen index + """ res = DMatrix(None) res.handle = ctypes.c_void_p(xglib.XGDMatrixSliceDMatrix( self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex))) return res +class CVPack: + def __init__(self, dtrain, dtest, param): + self.dtrain = dtrain + self.dtest = dtest + self.watchlist = watchlist = [ (dtrain,'train'), (dtest, 'test') ] + self.bst = Booster(param, [dtrain,dtest]) + def update(self,r): + self.bst.update(self.dtrain, r) + def eval(self,r): + return self.bst.eval_set(self.watchlist, r) + +def mknfold(dall, nfold, param, seed, weightscale=None): + """ + mk nfold list of cvpack from randidx + """ + randidx = range(dall.num_row()) + random.seed(seed) + random.shuffle(randidx) + + idxset = [] + kstep = len(randidx) / nfold + for i in range(nfold): + idxset.append(randidx[ (i*kstep) : min(len(randidx),(i+1)*kstep) ]) + + ret = [] + for k in range(nfold): + trainlst = [] + for j in range(nfold): + if j == k: + testlst = idxset[j] + else: + trainlst += idxset[j] + dtrain = dall.slice(trainlst) + dtest = dall.slice(testlst) + # rescale weight of dtrain and dtest + if weightscale != None: + dtrain.set_weight( dtrain.get_weight() * weightscale * dall.num_row() / dtrain.num_row() ) + dtest.set_weight( dtest.get_weight() * weightscale * dall.num_row() / dtest.num_row() ) + + ret.append(CVPack(dtrain, dtest, param)) + return ret + +def aggcv(rlist): + """ + aggregate cross validation results + """ + cvmap = {} + arr = rlist[0].split() + ret = arr[0] + for it in arr[1:]: + k, v = it.split(':') + cvmap[k] = [float(v)] + for line in rlist[1:]: + arr = line.split() + assert ret == arr[0] + for it in arr[1:]: + k, v = it.split(':') + cvmap[k].append(float(v)) + + for k, v in sorted(cvmap.items(), key = lambda x:x[0]): + v = np.array(v) + ret += '\t%s:%f+%f' % (k, np.mean(v), np.std(v)) + return ret + + class Booster: """learner class """ def __init__(self, params={}, cache=[], model_file = None): - """ constructor, param: """ + """ constructor + Args: + params: dict + params for boosters + cache: list + list of cache item + model_file: string + path of model file + Returns: + None + """ for d in cache: assert isinstance(d, DMatrix) dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache]) @@ -166,16 +317,30 @@ class Booster: xglib.XGBoosterSetParam( self.handle, ctypes.c_char_p(k.encode('utf-8')), ctypes.c_char_p(str(v).encode('utf-8'))) + def update(self, dtrain, it): """ update - dtrain: the training DMatrix - it: current iteration number + Args: + dtrain: DMatrix + the training DMatrix + it: int + current iteration number + Returns: + None """ assert isinstance(dtrain, DMatrix) xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle) def boost(self, dtrain, grad, hess): - """ update """ + """ update + Args: + dtrain: DMatrix + the training DMatrix + grad: list + the first order of gradient + hess: list + the second order of gradient + """ assert len(grad) == len(hess) assert isinstance(dtrain, DMatrix) xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle, @@ -183,6 +348,14 @@ class Booster: (ctypes.c_float*len(hess))(*hess), len(grad)) def eval_set(self, evals, it = 0): + """evaluates by metric + Args: + evals: list of tuple (DMatrix, string) + lists of items to be evaluated + it: int + Returns: + evals result + """ for d in evals: assert isinstance(d[0], DMatrix) assert isinstance(d[1], str) @@ -195,21 +368,46 @@ class Booster: def predict(self, data, output_margin=False): """ predict with data - data: the dmatrix storing the input - output_margin: whether output raw margin value that is untransformed + Args: + data: DMatrix + the dmatrix storing the input + output_margin: bool + whether output raw margin value that is untransformed + Returns: + numpy array of prediction """ length = ctypes.c_ulong() preds = xglib.XGBoosterPredict(self.handle, data.handle, int(output_margin), ctypes.byref(length)) return ctypes2numpy(preds, length.value, 'float32') def save_model(self, fname): - """ save model to file """ + """ save model to file + Args: + fname: string + file name of saving model + Returns: + None + """ xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8'))) def load_model(self, fname): - """load model from file""" + """load model from file + Args: + fname: string + file name of saving model + Returns: + None + """ xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) ) def dump_model(self, fo, fmap=''): - """dump model into text file""" + """dump model into text file + Args: + fo: string + file name to be dumped + fmap: string, optional + file name of feature map names + Returns: + None + """ if isinstance(fo,str): fo = open(fo,'w') need_close = True @@ -248,7 +446,17 @@ class Booster: return fmap def evaluate(bst, evals, it, feval = None): - """evaluation on eval set""" + """evaluation on eval set + Args: + bst: XGBoost object + object of XGBoost model + evals: list of tuple (DMatrix, string) + obj need to be evaluated + it: int + feval: optional + Returns: + eval result + """ if feval != None: res = '[%d]' % it for dm, evname in evals: @@ -259,8 +467,22 @@ def evaluate(bst, evals, it, feval = None): return res + + def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None): - """ train a booster with given paramaters """ + """ train a booster with given paramaters + Args: + params: dict + params of booster + dtrain: DMatrix + data to be trained + num_boost_round: int + num of round to be boosted + evals: list + list of items to be evaluated + obj: + feval: + """ bst = Booster(params, [dtrain]+[ d[0] for d in evals ] ) if obj == None: for i in range(num_boost_round): @@ -276,3 +498,27 @@ def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None if len(evals) != 0: sys.stderr.write(evaluate(bst, evals, i, feval)+'\n') return bst + +def cv(params, dtrain, num_boost_round = 10, nfold=3, evals = [], obj=None, feval=None): + """ cross validation with given paramaters + Args: + params: dict + params of booster + dtrain: DMatrix + data to be trained + num_boost_round: int + num of round to be boosted + nfold: int + folds to do cv + evals: list + list of items to be evaluated + obj: + feval: + """ + plst = list(params.items())+[('eval_metric', itm) for itm in evals] + cvfolds = mknfold(dtrain, nfold, plst, 0) + for i in range(num_boost_round): + for f in cvfolds: + f.update(i) + res = aggcv([f.eval(i) for f in cvfolds]) + sys.stderr.write(res+'\n') From ce2d34ecd4f4f4f7a965d997413144d6cd4d6f7a Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 29 Aug 2014 18:35:26 -0700 Subject: [PATCH 02/19] check unity back --- src/io/page_dmatrix-inl.hpp | 2 +- src/tree/param.h | 6 +- src/tree/updater_colmaker-inl.hpp | 194 ++++++++++++++++++++++++++---- src/utils/io.h | 1 - 4 files changed, 177 insertions(+), 26 deletions(-) diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 82a373352..df43d3b7f 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -18,7 +18,7 @@ struct RowBatchPage { utils::Assert(data_ != NULL, "fail to allocate row batch page"); this->Clear(); } - ~RowBatchPage(void) { + ~BinaryPage(void) { if (data_ != NULL) delete [] data_; } /*! diff --git a/src/tree/param.h b/src/tree/param.h index 52c273749..92bc1c990 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -37,7 +37,9 @@ struct TrainParam{ // speed optimization for dense column float opt_dense_col; // leaf vector size - int size_leaf_vector; + int size_leaf_vector; + // option for parallelization + int parallel_option; // number of threads to be used for tree construction, // if OpenMP is enabled, if equals 0, use system default int nthread; @@ -55,6 +57,7 @@ struct TrainParam{ opt_dense_col = 1.0f; nthread = 0; size_leaf_vector = 0; + parallel_option = 0; } /*! * \brief set parameters from outside @@ -79,6 +82,7 @@ struct TrainParam{ if (!strcmp(name, "size_leaf_vector")) size_leaf_vector = atoi(val); if (!strcmp(name, "max_depth")) max_depth = atoi(val); if (!strcmp(name, "nthread")) nthread = atoi(val); + if (!strcmp(name, "parallel_option")) parallel_option = atoi(val); if (!strcmp(name, "default_direction")) { if (!strcmp(val, "learn")) default_direction = 0; if (!strcmp(val, "left")) default_direction = 1; diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp index a8cf6ea7f..bf93cb7b5 100644 --- a/src/tree/updater_colmaker-inl.hpp +++ b/src/tree/updater_colmaker-inl.hpp @@ -45,15 +45,19 @@ class ColMaker: public IUpdater { // data structure /*! \brief per thread x per node entry to store tmp data */ struct ThreadEntry { - /*! \brief statistics of data*/ + /*! \brief statistics of data */ TStats stats; + /*! \brief extra statistics of data */ + TStats stats_extra; /*! \brief last feature value scanned */ float last_fvalue; + /*! \brief first feature value scanned */ + float first_fvalue; /*! \brief current best solution */ SplitEntry best; // constructor explicit ThreadEntry(const TrainParam ¶m) - : stats(param) { + : stats(param), stats_extra(param) { } }; struct NodeEntry { @@ -219,7 +223,137 @@ class ColMaker: public IUpdater { } // use new nodes for qexpand qexpand = newnodes; - } + } + // parallel find the best split of current fid + // this function does not support nested functions + inline void ParallelFindSplit(const ColBatch::Inst &col, + bst_uint fid, + const IFMatrix &fmat, + const std::vector &gpair, + const BoosterInfo &info) { + bool need_forward = param.need_forward_search(fmat.GetColDensity(fid)); + bool need_backward = param.need_backward_search(fmat.GetColDensity(fid)); + int nthread; + #pragma omp parallel + { + const int tid = omp_get_thread_num(); + std::vector &temp = stemp[tid]; + // cleanup temp statistics + for (size_t j = 0; j < qexpand.size(); ++j) { + temp[qexpand[j]].stats.Clear(); + } + nthread = omp_get_num_threads(); + bst_uint step = (col.length + nthread - 1) / nthread; + bst_uint end = std::min(col.length, step * (tid + 1)); + for (bst_uint i = tid * step; i < end; ++i) { + const bst_uint ridx = col[i].index; + const int nid = position[ridx]; + if (nid < 0) continue; + const float fvalue = col[i].fvalue; + if (temp[nid].stats.Empty()) { + temp[nid].first_fvalue = fvalue; + } + temp[nid].stats.Add(gpair, info, ridx); + temp[nid].last_fvalue = fvalue; + } + } + // start collecting the partial sum statistics + bst_omp_uint nnode = static_cast(qexpand.size()); + #pragma omp parallel for schedule(static) + for (bst_omp_uint j = 0; j < nnode; ++j) { + const int nid = qexpand[j]; + TStats sum(param), tmp(param), c(param); + for (int tid = 0; tid < nthread; ++tid) { + tmp = stemp[tid][nid].stats; + stemp[tid][nid].stats = sum; + sum.Add(tmp); + if (tid != 0) { + std::swap(stemp[tid - 1][nid].last_fvalue, stemp[tid][nid].first_fvalue); + } + } + for (int tid = 0; tid < nthread; ++tid) { + stemp[tid][nid].stats_extra = sum; + ThreadEntry &e = stemp[tid][nid]; + float fsplit; + if (tid != 0) { + if(fabsf(stemp[tid - 1][nid].last_fvalue - e.first_fvalue) > rt_2eps) { + fsplit = (stemp[tid - 1][nid].last_fvalue - e.first_fvalue) * 0.5f; + } else { + continue; + } + } else { + fsplit = e.first_fvalue - rt_eps; + } + if (need_forward && tid != 0) { + c.SetSubstract(snode[nid].stats, e.stats); + if (c.sum_hess >= param.min_child_weight && e.stats.sum_hess >= param.min_child_weight) { + bst_float loss_chg = static_cast(e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain); + e.best.Update(loss_chg, fid, fsplit, false); + } + } + if (need_backward) { + tmp.SetSubstract(sum, e.stats); + c.SetSubstract(snode[nid].stats, tmp); + if (c.sum_hess >= param.min_child_weight && tmp.sum_hess >= param.min_child_weight) { + bst_float loss_chg = static_cast(tmp.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain); + e.best.Update(loss_chg, fid, fsplit, true); + } + } + } + if (need_backward) { + tmp = sum; + ThreadEntry &e = stemp[nthread-1][nid]; + c.SetSubstract(snode[nid].stats, tmp); + if (c.sum_hess >= param.min_child_weight && tmp.sum_hess >= param.min_child_weight) { + bst_float loss_chg = static_cast(tmp.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain); + e.best.Update(loss_chg, fid, e.last_fvalue + rt_eps, true); + } + } + } + // rescan, generate candidate split + #pragma omp parallel + { + TStats c(param), cright(param); + const int tid = omp_get_thread_num(); + std::vector &temp = stemp[tid]; + nthread = static_cast(omp_get_num_threads()); + bst_uint step = (col.length + nthread - 1) / nthread; + bst_uint end = std::min(col.length, step * (tid + 1)); + for (bst_uint i = tid * step; i < end; ++i) { + const bst_uint ridx = col[i].index; + const int nid = position[ridx]; + if (nid < 0) continue; + const float fvalue = col[i].fvalue; + // get the statistics of nid + ThreadEntry &e = temp[nid]; + if (e.stats.Empty()) { + e.stats.Add(gpair, info, ridx); + e.first_fvalue = fvalue; + } else { + // forward default right + if (fabsf(fvalue - e.first_fvalue) > rt_2eps){ + if (need_forward) { + c.SetSubstract(snode[nid].stats, e.stats); + if (c.sum_hess >= param.min_child_weight && e.stats.sum_hess >= param.min_child_weight) { + bst_float loss_chg = static_cast(e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain); + e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, false); + } + } + if (need_backward) { + cright.SetSubstract(e.stats_extra, e.stats); + c.SetSubstract(snode[nid].stats, cright); + if (c.sum_hess >= param.min_child_weight && cright.sum_hess >= param.min_child_weight) { + bst_float loss_chg = static_cast(cright.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain); + e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true); + } + } + } + e.stats.Add(gpair, info, ridx); + e.first_fvalue = fvalue; + } + } + } + } // enumerate the split values of specific feature inline void EnumerateSplit(const ColBatch::Entry *begin, const ColBatch::Entry *end, @@ -272,6 +406,38 @@ class ColMaker: public IUpdater { } } } + // update the solution candidate + virtual void UpdateSolution(const ColBatch &batch, + const std::vector &gpair, + const IFMatrix &fmat, + const BoosterInfo &info) { + // start enumeration + const bst_omp_uint nsize = static_cast(batch.size); + #if defined(_OPENMP) + const int batch_size = std::max(static_cast(nsize / this->nthread / 32), 1); + #endif + if (param.parallel_option == 0) { + #pragma omp parallel for schedule(dynamic, batch_size) + for (bst_omp_uint i = 0; i < nsize; ++i) { + const bst_uint fid = batch.col_index[i]; + const int tid = omp_get_thread_num(); + const ColBatch::Inst c = batch[i]; + if (param.need_forward_search(fmat.GetColDensity(fid))) { + this->EnumerateSplit(c.data, c.data + c.length, +1, + fid, gpair, info, stemp[tid]); + } + if (param.need_backward_search(fmat.GetColDensity(fid))) { + this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1, + fid, gpair, info, stemp[tid]); + } + } + } else { + for (bst_omp_uint i = 0; i < nsize; ++i) { + this->ParallelFindSplit(batch[i], batch.col_index[i], + fmat, gpair, info); + } + } + } // find splits at current level, do split per level inline void FindSplit(int depth, const std::vector &qexpand, @@ -288,26 +454,7 @@ class ColMaker: public IUpdater { } utils::IIterator *iter = p_fmat->ColIterator(feat_set); while (iter->Next()) { - const ColBatch &batch = iter->Value(); - // start enumeration - const bst_omp_uint nsize = static_cast(batch.size); - #if defined(_OPENMP) - const int batch_size = std::max(static_cast(nsize / this->nthread / 32), 1); - #endif - #pragma omp parallel for schedule(dynamic, batch_size) - for (bst_omp_uint i = 0; i < nsize; ++i) { - const bst_uint fid = batch.col_index[i]; - const int tid = omp_get_thread_num(); - const ColBatch::Inst c = batch[i]; - if (param.need_forward_search(p_fmat->GetColDensity(fid))) { - this->EnumerateSplit(c.data, c.data + c.length, +1, - fid, gpair, info, stemp[tid]); - } - if (param.need_backward_search(p_fmat->GetColDensity(fid))) { - this->EnumerateSplit(c.data + c.length - 1, c.data - 1, -1, - fid, gpair, info, stemp[tid]); - } - } + this->UpdateSolution(iter->Value(), gpair, *p_fmat, info); } // after this each thread's stemp will get the best candidates, aggregate results for (size_t i = 0; i < qexpand.size(); ++i) { @@ -325,6 +472,7 @@ class ColMaker: public IUpdater { } } } + // reset position of each data points after split is created in the tree inline void ResetPosition(const std::vector &qexpand, IFMatrix *p_fmat, const RegTree &tree) { const std::vector &rowset = p_fmat->buffered_rowset(); diff --git a/src/utils/io.h b/src/utils/io.h index 23fa0d468..141d83f8c 100644 --- a/src/utils/io.h +++ b/src/utils/io.h @@ -100,7 +100,6 @@ class ISeekStream: public IStream { /*! \brief implementation of file i/o stream */ class FileStream : public ISeekStream { public: - explicit FileStream(void) {} explicit FileStream(FILE *fp) { this->fp = fp; } From d0e27482efe2917c60977f46fedf2836e6e18744 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 29 Aug 2014 18:44:02 -0700 Subject: [PATCH 03/19] fix compiler error --- src/io/page_dmatrix-inl.hpp | 2 +- src/utils/io.h | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index df43d3b7f..82a373352 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -18,7 +18,7 @@ struct RowBatchPage { utils::Assert(data_ != NULL, "fail to allocate row batch page"); this->Clear(); } - ~BinaryPage(void) { + ~RowBatchPage(void) { if (data_ != NULL) delete [] data_; } /*! diff --git a/src/utils/io.h b/src/utils/io.h index 141d83f8c..23fa0d468 100644 --- a/src/utils/io.h +++ b/src/utils/io.h @@ -100,6 +100,7 @@ class ISeekStream: public IStream { /*! \brief implementation of file i/o stream */ class FileStream : public ISeekStream { public: + explicit FileStream(void) {} explicit FileStream(FILE *fp) { this->fp = fp; } From ce772c2f3e9cfcc06ea44287c6825a0e08f0efb0 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 29 Aug 2014 19:59:19 -0700 Subject: [PATCH 04/19] first check of page --- src/io/io.cpp | 13 +++++++- src/io/page_dmatrix-inl.hpp | 65 +++++++++++++++++++++++++++++++++++-- src/utils/io.h | 8 +++-- 3 files changed, 81 insertions(+), 5 deletions(-) diff --git a/src/io/io.cpp b/src/io/io.cpp index f56cff679..e8b9ce337 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -22,7 +22,13 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { dmat->LoadBinary(fs, silent, fname); fs.Close(); return dmat; - } + } + if (magic == DMatrixPage::kMagic) { + DMatrixPage *dmat = new DMatrixPage(); + dmat->Load(fs, silent, fname); + // the file pointer is hold in page matrix + return dmat; + } fs.Close(); DMatrixSimple *dmat = new DMatrixSimple(); @@ -31,6 +37,11 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { } void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) { + if (!strcmp(fname + strlen(fname) - 5, ".page")) { + + DMatrixPage::Save(fname, dmat, silent); + return; + } if (dmat.magic == DMatrixSimple::kMagic) { const DMatrixSimple *p_dmat = static_cast(&dmat); p_dmat->SaveBinary(fname, silent); diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 82a373352..4d13a0bc5 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -8,6 +8,8 @@ #include "../data.h" #include "../utils/iterator.h" #include "../utils/thread_buffer.h" +#include "./simple_fmatrix-inl.hpp" + namespace xgboost { namespace io { /*! \brief page structure that can be used to store a rowbatch */ @@ -102,7 +104,7 @@ class ThreadRowPageIterator: public utils::IIterator { base_rowid_ = 0; isend_ = false; } - virtual ~ThreadRowPageIterator(void) { + virtual ~ThreadRowPageIterator(void) { } virtual void Init(void) { } @@ -188,7 +190,9 @@ class ThreadRowPageIterator: public utils::IIterator { inline void FreeSpace(PagePtr &a) { delete a; } - inline void Destroy(void) {} + inline void Destroy(void) { + fi.Close(); + } inline void BeforeFirst(void) { fi.Seek(file_begin_); } @@ -199,6 +203,63 @@ class ThreadRowPageIterator: public utils::IIterator { int ptop_; utils::ThreadBuffer itr; }; + +/*! \brief data matrix using page */ +class DMatrixPage : public DataMatrix { + public: + DMatrixPage(void) : DataMatrix(kMagic) { + iter_ = new ThreadRowPageIterator(); + fmat_ = new FMatrixS(iter_); + } + // virtual destructor + virtual ~DMatrixPage(void) { + delete fmat_; + } + virtual IFMatrix *fmat(void) const { + return fmat_; + } + /*! \brief load and initialize the iterator with fi */ + inline void Load(utils::FileStream &fi, + bool silent = false, + const char *fname = NULL){ + int magic; + utils::Check(fi.Read(&magic, sizeof(magic)) != 0, "invalid input file format"); + utils::Check(magic == kMagic, "invalid format,magic number mismatch"); + this->info.LoadBinary(fi); + iter_->Load(fi); + if (!silent) { + printf("DMatrixPage: %lux%lu matrix is loaded", + info.num_row(), info.num_col()); + if (fname != NULL) { + printf(" from %s\n", fname); + } else { + printf("\n"); + } + if (info.group_ptr.size() != 0) { + printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1); + } + } + } + /*! \brief save a DataMatrix as DMatrixPage*/ + inline static void Save(const char* fname, const DataMatrix &mat, bool silent) { + utils::FileStream fs(utils::FopenCheck(fname, "wb")); + int magic = kMagic; + fs.Write(&magic, sizeof(magic)); + mat.info.SaveBinary(fs); + ThreadRowPageIterator::Save(mat.fmat()->RowIterator(), fs); + fs.Close(); + if (!silent) { + printf("DMatrixPage: %lux%lu is saved to %s\n", + mat.info.num_row(), mat.info.num_col(), fname); + } + } + /*! \brief the real fmatrix */ + FMatrixS *fmat_; + /*! \brief row iterator */ + ThreadRowPageIterator *iter_; + /*! \brief magic number used to identify DMatrix */ + static const int kMagic = 0xffffab02; +}; } // namespace io } // namespace xgboost #endif // XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_ diff --git a/src/utils/io.h b/src/utils/io.h index 23fa0d468..dbfcee3f6 100644 --- a/src/utils/io.h +++ b/src/utils/io.h @@ -100,7 +100,9 @@ class ISeekStream: public IStream { /*! \brief implementation of file i/o stream */ class FileStream : public ISeekStream { public: - explicit FileStream(void) {} + explicit FileStream(void) { + this->fp = NULL; + } explicit FileStream(FILE *fp) { this->fp = fp; } @@ -117,7 +119,9 @@ class FileStream : public ISeekStream { return static_cast(ftell(fp)); } inline void Close(void) { - fclose(fp); + if (fp != NULL){ + fclose(fp); fp = NULL; + } } private: From 7bc1c3ee79e31afebafe149f96a456d3a2f0ec82 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 29 Aug 2014 20:54:24 -0700 Subject: [PATCH 05/19] various fix of page --- src/io/io.cpp | 7 ++++--- src/io/page_dmatrix-inl.hpp | 12 ++++++++---- src/io/simple_dmatrix-inl.hpp | 2 +- src/io/simple_fmatrix-inl.hpp | 2 +- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/io/io.cpp b/src/io/io.cpp index e8b9ce337..e413b2799 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -37,8 +37,7 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { } void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) { - if (!strcmp(fname + strlen(fname) - 5, ".page")) { - + if (!strcmp(fname + strlen(fname) - 5, ".page")) { DMatrixPage::Save(fname, dmat, silent); return; } @@ -46,7 +45,9 @@ void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) { const DMatrixSimple *p_dmat = static_cast(&dmat); p_dmat->SaveBinary(fname, silent); } else { - utils::Error("not implemented"); + DMatrixSimple smat; + smat.CopyFrom(dmat); + smat.SaveBinary(fname, silent); } } diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 4d13a0bc5..2701fb3b3 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -32,7 +32,7 @@ struct RowBatchPage { const size_t dsize = row.length * sizeof(RowBatch::Entry); if (FreeBytes() < dsize+ sizeof(int)) return false; row_ptr(Size() + 1) = row_ptr(Size()) + row.length; - memcpy(data_ptr(Size()) , row.data, dsize); + memcpy(data_ptr(row_ptr(Size())) , row.data, dsize); ++ data_[0]; return true; } @@ -48,13 +48,18 @@ struct RowBatchPage { batch.data_ptr = this->data_ptr(0); batch.size = static_cast(this->Size()); std::vector &rptr = *p_rptr; - rptr.resize(this->Size()+1); + rptr.resize(this->Size() + 1); for (size_t i = 0; i < rptr.size(); ++i) { rptr[i] = static_cast(this->row_ptr(i)); } batch.ind_ptr = &rptr[0]; return batch; } + /*! \brief get i-th row from the batch */ + inline RowBatch::Inst operator[](size_t i) { + return RowBatch::Inst(data_ptr(0) + row_ptr(i), + static_cast(row_ptr(i+1) - row_ptr(i))); + } /*! * \brief clear the page, cleanup the content */ @@ -77,7 +82,7 @@ struct RowBatchPage { return data_[0]; } /*! \brief page size 64 MB */ - static const size_t kPageSize = 64 << 18; + static const size_t kPageSize = 64 << 8; private: /*! \return number of elements */ @@ -112,7 +117,6 @@ class ThreadRowPageIterator: public utils::IIterator { itr.BeforeFirst(); isend_ = false; base_rowid_ = 0; - utils::Assert(this->LoadNextPage(), "ThreadRowPageIterator"); } virtual bool Next(void) { if(!this->LoadNextPage()) return false; diff --git a/src/io/simple_dmatrix-inl.hpp b/src/io/simple_dmatrix-inl.hpp index 47be8a41a..8d7064bdd 100644 --- a/src/io/simple_dmatrix-inl.hpp +++ b/src/io/simple_dmatrix-inl.hpp @@ -44,8 +44,8 @@ class DMatrixSimple : public DataMatrix { } /*! \brief copy content data from source matrix */ inline void CopyFrom(const DataMatrix &src) { - this->info = src.info; this->Clear(); + this->info = src.info; // clone data content in thos matrix utils::IIterator *iter = src.fmat()->RowIterator(); iter->BeforeFirst(); diff --git a/src/io/simple_fmatrix-inl.hpp b/src/io/simple_fmatrix-inl.hpp index 86763a105..f099eb1a9 100644 --- a/src/io/simple_fmatrix-inl.hpp +++ b/src/io/simple_fmatrix-inl.hpp @@ -150,7 +150,7 @@ class FMatrixS : public IFMatrix{ iter_->BeforeFirst(); while (iter_->Next()) { const RowBatch &batch = iter_->Value(); - for (size_t i = 0; i < batch.size; ++i) { + for (size_t i = 0; i < batch.size; ++i) { if (pkeep == 1.0f || random::SampleBinary(pkeep)) { buffered_rowset_.push_back(static_cast(batch.base_rowid+i)); RowBatch::Inst inst = batch[i]; From 9830674b75c954a6ca02546d60367fae9be1e6d9 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 29 Aug 2014 21:04:40 -0700 Subject: [PATCH 06/19] seems page is ok, try add col tmr --- src/io/page_dmatrix-inl.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 2701fb3b3..e16beb4b6 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -82,7 +82,7 @@ struct RowBatchPage { return data_[0]; } /*! \brief page size 64 MB */ - static const size_t kPageSize = 64 << 8; + static const size_t kPageSize = 64 << 18; private: /*! \return number of elements */ @@ -104,7 +104,7 @@ struct RowBatchPage { class ThreadRowPageIterator: public utils::IIterator { public: ThreadRowPageIterator(void) { - itr.SetParam("buffer_size", "4"); + itr.SetParam("buffer_size", "2"); page_ = NULL; base_rowid_ = 0; isend_ = false; From 366ac95ad331fe68970c39cb74e4fad95cde6045 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 29 Aug 2014 21:27:03 -0700 Subject: [PATCH 07/19] windows check --- .gitignore | 2 +- R-package/src/Makevars | 5 +++-- R-package/src/Makevars.win | 5 +++-- src/io/page_dmatrix-inl.hpp | 6 +++--- src/utils/io.h | 10 +++++----- wrapper/xgboost.py | 4 ++-- 6 files changed, 17 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index f1f9400ab..4551c79cc 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ *.slo *.lo *.o - +*.page # Compiled Dynamic libraries *.so *.dylib diff --git a/R-package/src/Makevars b/R-package/src/Makevars index 7dfda4d57..b0d3283b9 100644 --- a/R-package/src/Makevars +++ b/R-package/src/Makevars @@ -5,9 +5,10 @@ CXX=`R CMD config CXX` CFLAGS=`R CMD config CFLAGS` # expose these flags to R CMD SHLIB PKG_CPPFLAGS= -DXGBOOST_CUSTOMIZE_ERROR_ -I$(PKGROOT) $(SHLIB_OPENMP_CFLAGS) -XGBFLAG= $(CFLAGS) -DXGBOOST_CUSTOMIZE_ERROR_ -fPIC $(SHLIB_OPENMP_CFLAGS) +PKG_CPPFLAGS+= $(SHLIB_PTHREAD_FLAGS) +XGBFLAG= $(CFLAGS) -DXGBOOST_CUSTOMIZE_ERROR_ -fPIC $(SHLIB_OPENMP_CFLAGS) $(SHLIB_PTHREAD_FLAGS) -PKG_LIBS = $(SHLIB_OPENMP_CFLAGS) +PKG_LIBS = $(SHLIB_OPENMP_CFLAGS) $(SHLIB_PTHREAD_FLAGS) ifeq ($(no_omp),1) PKG_CPPFLAGS += -DDISABLE_OPENMP diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 3df9891fc..8f5f7ed98 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -5,8 +5,9 @@ CXX=`Rcmd config CXX` CFLAGS=`Rcmd config CFLAGS` # expose these flags to R CMD SHLIB PKG_CPPFLAGS= -DXGBOOST_CUSTOMIZE_ERROR_ -I$(PKGROOT) $(SHLIB_OPENMP_CFLAGS) -XGBFLAG= $(CFLAGS) -DXGBOOST_CUSTOMIZE_ERROR_ -fPIC $(SHLIB_OPENMP_CFLAGS) -PKG_LIBS = $(SHLIB_OPENMP_CFLAGS) +PKG_CPPFLAGS+= $(SHLIB_PTHREAD_FLAGS) +XGBFLAG= $(CFLAGS) -DXGBOOST_CUSTOMIZE_ERROR_ -fPIC $(SHLIB_OPENMP_CFLAGS) $(SHLIB_PTHREAD_FLAGS) +PKG_LIBS = $(SHLIB_OPENMP_CFLAGS) $(SHLIB_PTHREAD_FLAGS) ifeq ($(no_omp),1) PKG_CPPFLAGS += -DDISABLE_OPENMP diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index e16beb4b6..8db944c85 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -50,13 +50,13 @@ struct RowBatchPage { std::vector &rptr = *p_rptr; rptr.resize(this->Size() + 1); for (size_t i = 0; i < rptr.size(); ++i) { - rptr[i] = static_cast(this->row_ptr(i)); + rptr[i] = static_cast(this->row_ptr(static_cast(i))); } batch.ind_ptr = &rptr[0]; return batch; } /*! \brief get i-th row from the batch */ - inline RowBatch::Inst operator[](size_t i) { + inline RowBatch::Inst operator[](int i) { return RowBatch::Inst(data_ptr(0) + row_ptr(i), static_cast(row_ptr(i+1) - row_ptr(i))); } @@ -173,7 +173,7 @@ class ThreadRowPageIterator: public utils::IIterator { // loader factory for page struct Factory { public: - size_t file_begin_; + long file_begin_; utils::FileStream fi; Factory(void) {} inline void SetFile(const utils::FileStream &fi) { diff --git a/src/utils/io.h b/src/utils/io.h index dbfcee3f6..276dd7312 100644 --- a/src/utils/io.h +++ b/src/utils/io.h @@ -92,9 +92,9 @@ class IStream { class ISeekStream: public IStream { public: /*! \brief seek to certain position of the file */ - virtual void Seek(size_t pos) = 0; + virtual void Seek(long pos) = 0; /*! \brief tell the position of the stream */ - virtual size_t Tell(void) = 0; + virtual long Tell(void) = 0; }; /*! \brief implementation of file i/o stream */ @@ -112,11 +112,11 @@ class FileStream : public ISeekStream { virtual void Write(const void *ptr, size_t size) { fwrite(ptr, size, 1, fp); } - virtual void Seek(size_t pos) { + virtual void Seek(long pos) { fseek(fp, pos, SEEK_SET); } - virtual size_t Tell(void) { - return static_cast(ftell(fp)); + virtual long Tell(void) { + return ftell(fp); } inline void Close(void) { if (fp != NULL){ diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index adf59c829..e4338e0cd 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -56,7 +56,7 @@ class DMatrix: weight for each instances """ # force into void_p, mac need to pass things in as void_p - if data == None: + if data is None: self.handle = None return if isinstance(data, str): @@ -484,7 +484,7 @@ def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None feval: """ bst = Booster(params, [dtrain]+[ d[0] for d in evals ] ) - if obj == None: + if obj is None: for i in range(num_boost_round): bst.update( dtrain, i ) if len(evals) != 0: From 0a7cfb32c6f453482bdc8a8f5fa4f2e2db1308fc Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 31 Aug 2014 21:58:01 -0700 Subject: [PATCH 08/19] add fmatrix, fight tmr --- src/io/page_dmatrix-inl.hpp | 1 + src/io/page_fmatrix-inl.hpp | 75 +++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 src/io/page_fmatrix-inl.hpp diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 8db944c85..23013b98b 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -9,6 +9,7 @@ #include "../utils/iterator.h" #include "../utils/thread_buffer.h" #include "./simple_fmatrix-inl.hpp" +#include "./page_fmatrix-inl.hpp" namespace xgboost { namespace io { diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp new file mode 100644 index 000000000..156cddb63 --- /dev/null +++ b/src/io/page_fmatrix-inl.hpp @@ -0,0 +1,75 @@ +#ifndef XGBOOST_IO_PAGE_FMATRIX_INL_HPP_ +#define XGBOOST_IO_PAGE_FMATRIX_INL_HPP_ +/*! + * \file page_fmatrix-inl.hpp + * sparse page manager for fmatrix + * \author Tianqi Chen + */ +#include "../data.h" +#include "../utils/iterator.h" +#include "../utils/thread_buffer.h" +namespace xgboost { +namespace io { + +class CSCMatrixManager { + public: + /*! \brief in memory page */ + struct Page { + public: + /*! \brief initialize the page */ + inline void Init(size_t size) { + buffer.resize(size); + num_entry = 0; + col_index.clear(); + col_data.clear(); + } + /*! \brief number of used entries */ + size_t num_entry; + /*! \brief column index */ + std::vector col_index; + /*! \brief column data */ + std::vector col_data; + /*! \brief number of free entries */ + inline size_t NumFreeEntry(void) const { + return buffer.size() - num_entry; + } + inline ColBatch::Entry* AllocEntry(size_t len) { + ColBatch::Entry *p_data = &buffer[0] + num_entry; + num_entry += len; + return p_data; + } + /*! \brief get underlying batch */ + inline ColBatch GetBatch(void) const { + ColBatch batch; + batch.col_index = &col_index[0]; + batch.col_data = &col_data[0]; + return batch; + } + private: + /*! \brief buffer space, not to be changed since ready */ + std::vector buffer; + }; + + private: + /*! \brief fill a page with */ + inline bool Fill(size_t cidx, Page *p_page) { + size_t len = col_ptr_[cidx+1] - col_ptr_[cidx]; + if (p_page->NumFreeEntry() < len) return false; + ColBatch::Entry *p_data = p_page->AllocEntry(len); + fi->Seek(col_ptr_[cidx]); + utils::Check(fi->Read(p_data, sizeof(ColBatch::Entry) * len) != 0, + "invalid column buffer format"); + p_page->col_data.push_back(ColBatch::Inst(p_data, len)); + p_page->col_index.push_back(cidx); + } + /*! \brief size of data content */ + size_t data_size_; + /*! \brief input stream */ + utils::ISeekStream *fi; + /*! \brief column pointer of CSC format */ + std::vector col_ptr_; +}; + +} // namespace io +} // namespace xgboost +#endif // XGBOOST_IO_PAGE_FMATRIX_INL_HPP_ From e3153b976c589ce0d49c555188809b7c2ab160da Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 31 Aug 2014 22:25:30 -0700 Subject: [PATCH 09/19] chgs --- src/io/page_fmatrix-inl.hpp | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index 156cddb63..f077f0dde 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -17,12 +17,17 @@ class CSCMatrixManager { struct Page { public: /*! \brief initialize the page */ - inline void Init(size_t size) { + explicit Page(size_t size) { buffer.resize(size); + col_index.reserve(10); + col_data.reserved(10); + } + /*! \brief clear the page */ + inline void Clear(void) { num_entry = 0; col_index.clear(); col_data.clear(); - } + } /*! \brief number of used entries */ size_t num_entry; /*! \brief column index */ @@ -49,6 +54,33 @@ class CSCMatrixManager { /*! \brief buffer space, not to be changed since ready */ std::vector buffer; }; + /*! \brief define type of page pointer */ + typedef Page *PagePtr; + /*! \brief get column pointer */ + const std::vector &col_ptr(void) const { + return col_ptr_; + } + inline bool Init(void) { + return true; + } + inline void SetParam(const char *name, const char *val) { + } + inline bool LoadNext(PagePtr &val) { + + } + inline PagePtr Create(void) { + PagePtr a = new Page(); + return a; + } + inline void FreeSpace(PagePtr &a) { + delete a; + } + inline void Destroy(void) { + fi.Close(); + } + inline void BeforeFirst(void) { + fi.Seek(file_begin_); + } private: /*! \brief fill a page with */ From 7d1e9f06d43b2601963e91ebb6fa871c0dbf7426 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 1 Sep 2014 10:45:05 -0700 Subject: [PATCH 10/19] add fmatrix in, todo add buffer file --- src/io/page_dmatrix-inl.hpp | 20 +-- src/io/page_fmatrix-inl.hpp | 248 +++++++++++++++++++++++++++++++++--- src/utils/thread_buffer.h | 3 + 3 files changed, 237 insertions(+), 34 deletions(-) diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 23013b98b..01b7f8fc7 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -108,7 +108,6 @@ class ThreadRowPageIterator: public utils::IIterator { itr.SetParam("buffer_size", "2"); page_ = NULL; base_rowid_ = 0; - isend_ = false; } virtual ~ThreadRowPageIterator(void) { } @@ -116,11 +115,10 @@ class ThreadRowPageIterator: public utils::IIterator { } virtual void BeforeFirst(void) { itr.BeforeFirst(); - isend_ = false; base_rowid_ = 0; } virtual bool Next(void) { - if(!this->LoadNextPage()) return false; + if(!itr.Next(page_)) return false; out_ = page_->GetRowBatch(&tmp_ptr_, base_rowid_); base_rowid_ += out_.size; return true; @@ -154,21 +152,12 @@ class ThreadRowPageIterator: public utils::IIterator { if (page.Size() != 0) page.Save(fo); } private: - // load in next page - inline bool LoadNextPage(void) { - ptop_ = 0; - bool ret = itr.Next(page_); - isend_ = !ret; - return ret; - } // base row id size_t base_rowid_; // temporal ptr std::vector tmp_ptr_; // output data RowBatch out_; - // whether we reach end of file - bool isend_; // page pointer type typedef RowBatchPage* PagePtr; // loader factory for page @@ -205,7 +194,6 @@ class ThreadRowPageIterator: public utils::IIterator { protected: PagePtr page_; - int ptop_; utils::ThreadBuffer itr; }; @@ -234,7 +222,8 @@ class DMatrixPage : public DataMatrix { iter_->Load(fi); if (!silent) { printf("DMatrixPage: %lux%lu matrix is loaded", - info.num_row(), info.num_col()); + static_cast(info.num_row()), + static_cast(info.num_col())); if (fname != NULL) { printf(" from %s\n", fname); } else { @@ -255,7 +244,8 @@ class DMatrixPage : public DataMatrix { fs.Close(); if (!silent) { printf("DMatrixPage: %lux%lu is saved to %s\n", - mat.info.num_row(), mat.info.num_col(), fname); + static_cast(mat.info.num_row()), + static_cast(mat.info.num_col()), fname); } } /*! \brief the real fmatrix */ diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index f077f0dde..cf4923b7b 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -20,7 +20,7 @@ class CSCMatrixManager { explicit Page(size_t size) { buffer.resize(size); col_index.reserve(10); - col_data.reserved(10); + col_data.reserve(10); } /*! \brief clear the page */ inline void Clear(void) { @@ -57,49 +57,259 @@ class CSCMatrixManager { /*! \brief define type of page pointer */ typedef Page *PagePtr; /*! \brief get column pointer */ - const std::vector &col_ptr(void) const { + inline const std::vector &col_ptr(void) const { return col_ptr_; } - inline bool Init(void) { - return true; - } inline void SetParam(const char *name, const char *val) { - } - inline bool LoadNext(PagePtr &val) { - } inline PagePtr Create(void) { - PagePtr a = new Page(); - return a; + return new Page(page_size_); } inline void FreeSpace(PagePtr &a) { delete a; } inline void Destroy(void) { - fi.Close(); } inline void BeforeFirst(void) { - fi.Seek(file_begin_); + col_index_ = col_todo_; + read_top_ = 0; + } + inline bool LoadNext(PagePtr &val) { + val->Clear(); + if (read_top_ >= col_index_.size()) return false; + while (read_top_ < col_index_.size()) { + if (!this->TryFill(col_index_[read_top_], val)) return true; + ++read_top_; + } + return true; + } + inline bool Init(void) { + this->BeforeFirst(); + return true; + } + inline void Setup(utils::ISeekStream *fi, double page_ratio) { + fi_ = fi; + fi_->Read(&begin_meta_ , sizeof(size_t)); + fi_->Seek(begin_meta_); + fi_->Read(&col_ptr_); + size_t psmax = 0; + for (size_t i = 0; i < col_ptr_.size() - 1; ++i) { + psmax = std::max(psmax, col_ptr_[i+1] - col_ptr_[i]); + } + utils::Check(page_ratio >= 1.0f, "col_page_ratio must be at least 1"); + page_size_ = std::max(static_cast(psmax * page_ratio), psmax); + } + inline void SetColSet(const std::vector &cset, bool setall) { + if (!setall) { + col_todo_.resize(cset.size()); + for (size_t i = 0; i < cset.size(); ++i) { + col_todo_[i] = cset[i]; + utils::Assert(col_todo_[i] < static_cast(col_ptr_.size() - 1), + "CSCMatrixManager: column index exceed bound"); + } + std::sort(col_todo_.begin(), col_todo_.end()); + } else { + col_todo_.resize(col_ptr_.size()-1); + for (size_t i = 0; i < col_todo_.size(); ++i) { + col_todo_[i] = static_cast(i); + } + } } - private: /*! \brief fill a page with */ - inline bool Fill(size_t cidx, Page *p_page) { + inline bool TryFill(size_t cidx, Page *p_page) { size_t len = col_ptr_[cidx+1] - col_ptr_[cidx]; if (p_page->NumFreeEntry() < len) return false; ColBatch::Entry *p_data = p_page->AllocEntry(len); - fi->Seek(col_ptr_[cidx]); - utils::Check(fi->Read(p_data, sizeof(ColBatch::Entry) * len) != 0, + fi_->Seek(col_ptr_[cidx]); + utils::Check(fi_->Read(p_data, sizeof(ColBatch::Entry) * len) != 0, "invalid column buffer format"); p_page->col_data.push_back(ColBatch::Inst(p_data, len)); p_page->col_index.push_back(cidx); } + // the following are in memory auxiliary data structure + /*! \brief top of reader position */ + size_t read_top_; + /*! \brief size of page */ + size_t page_size_; + /*! \brief column index to be loaded */ + std::vector col_index_; + /*! \brief column index to be after calling before first */ + std::vector col_todo_; + // the following are input content /*! \brief size of data content */ - size_t data_size_; + size_t begin_meta_; /*! \brief input stream */ - utils::ISeekStream *fi; + utils::ISeekStream *fi_; /*! \brief column pointer of CSC format */ - std::vector col_ptr_; + std::vector col_ptr_; +}; + +class ThreadColPageIterator : public utils::IIterator { + public: + ThreadColPageIterator(void) { + itr_.SetParam("buffer_size", "2"); + page_ = NULL; + fi_ = NULL; + silent = 0; + } + virtual ~ThreadColPageIterator(void) { + if (fi_ != NULL) { + fi_->Close(); delete fi_; + } + } + virtual void Init(void) { + fi_ = new utils::FileStream(utils::FopenCheck(col_pagefile_.c_str(), "rb")); + itr_.get_factory().Setup(fi_, col_pageratio_); + if (silent == 0) { + printf("ThreadColPageIterator: finish initialzing from %s, %u columns\n", + col_pagefile_.c_str(), static_cast(col_ptr().size() - 1)); + } + } + virtual void SetParam(const char *name, const char *val) { + if (!strcmp("col_pageratio", val)) col_pageratio_ = atof(val); + if (!strcmp("col_pagefile", val)) col_pagefile_ = val; + if (!strcmp("silent", val)) silent = atoi(val); + } + virtual void BeforeFirst(void) { + itr_.BeforeFirst(); + } + virtual bool Next(void) { + if(!itr_.Next(page_)) return false; + out_ = page_->GetBatch(); + return true; + } + virtual const ColBatch &Value(void) const{ + return out_; + } + inline const std::vector &col_ptr(void) const { + return itr_.get_factory().col_ptr(); + } + inline void SetColSet(const std::vector &cset, bool setall = false) { + itr_.get_factory().SetColSet(cset, setall); + } + + private: + // shutup + int silent; + // input file + utils::FileStream *fi_; + // size of page + float col_pageratio_; + // name of file + std::string col_pagefile_; + // output data + ColBatch out_; + // page to be loaded + CSCMatrixManager::PagePtr page_; + // internal iterator + utils::ThreadBuffer itr_; +}; + +/*! + * \brief sparse matrix that support column access + */ +class FMatrixPage : public IFMatrix { + public: + /*! \brief constructor */ + FMatrixPage(utils::IIterator *iter) { + this->row_iter_ = iter; + this->col_iter_ = NULL; + } + // destructor + virtual ~FMatrixPage(void) { + if (row_iter_ != NULL) delete row_iter_; + if (col_iter_ != NULL) delete col_iter_; + } + /*! \return whether column access is enabled */ + virtual bool HaveColAccess(void) const { + return col_iter_ != NULL; + } + /*! \brief get number of colmuns */ + virtual size_t NumCol(void) const { + utils::Check(this->HaveColAccess(), "NumCol:need column access"); + return col_iter_->col_ptr().size() - 1; + } + /*! \brief get number of buffered rows */ + virtual const std::vector &buffered_rowset(void) const { + return buffered_rowset_; + } + /*! \brief get column size */ + virtual size_t GetColSize(size_t cidx) const { + const std::vector &col_ptr = col_iter_->col_ptr(); + return col_ptr[cidx+1] - col_ptr[cidx]; + } + /*! \brief get column density */ + virtual float GetColDensity(size_t cidx) const { + const std::vector &col_ptr = col_iter_->col_ptr(); + size_t nmiss = buffered_rowset_.size() - (col_ptr[cidx+1] - col_ptr[cidx]); + return 1.0f - (static_cast(nmiss)) / buffered_rowset_.size(); + } + virtual void InitColAccess(float pkeep = 1.0f) { + if (this->HaveColAccess()) return; + this->InitColData(pkeep); + } + /*! + * \brief get the row iterator associated with FMatrix + */ + virtual utils::IIterator* RowIterator(void) { + row_iter_->BeforeFirst(); + return row_iter_; + } + /*! + * \brief get the column based iterator + */ + virtual utils::IIterator* ColIterator(void) { + std::vector cset; + col_iter_->SetColSet(cset, true); + col_iter_->BeforeFirst(); + return col_iter_; + } + /*! + * \brief colmun based iterator + */ + virtual utils::IIterator *ColIterator(const std::vector &fset) { + col_iter_->SetColSet(fset, false); + col_iter_->BeforeFirst(); + return col_iter_; + } + + protected: + /*! + * \brief intialize column data + * \param pkeep probability to keep a row + */ + inline void InitColData(float pkeep) { + buffered_rowset_.clear(); + // start working + row_iter_->BeforeFirst(); + while (row_iter_->Next()) { + const RowBatch &batch = row_iter_->Value(); + for (size_t i = 0; i < batch.size; ++i) { + } + } + row_iter_->BeforeFirst(); + size_t ktop = 0; + while (row_iter_->Next()) { + const RowBatch &batch = row_iter_->Value(); + for (size_t i = 0; i < batch.size; ++i) { + if (ktop < buffered_rowset_.size() && + buffered_rowset_[ktop] == batch.base_rowid+i) { + ++ktop; + // TODO1 + } + } + } + // sort columns + } + + private: + // row iterator + utils::IIterator *row_iter_; + // column iterator + ThreadColPageIterator *col_iter_; + /*! \brief list of row index that are buffered */ + std::vector buffered_rowset_; }; } // namespace io diff --git a/src/utils/thread_buffer.h b/src/utils/thread_buffer.h index fa488a220..ace50c4b8 100644 --- a/src/utils/thread_buffer.h +++ b/src/utils/thread_buffer.h @@ -113,6 +113,9 @@ class ThreadBuffer { inline ElemFactory &get_factory(void) { return factory; } + inline const ElemFactory &get_factory(void) const{ + return factory; + } // size of buffer int buf_size; private: From 9d3e09ff2a3fded1dbd1204ddb0f0722955ff24d Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 1 Sep 2014 20:44:15 -0700 Subject: [PATCH 11/19] make rowbatch page flexible --- src/io/page_dmatrix-inl.hpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 01b7f8fc7..76767d942 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -16,7 +16,7 @@ namespace io { /*! \brief page structure that can be used to store a rowbatch */ struct RowBatchPage { public: - RowBatchPage(void) { + RowBatchPage(size_t page_size) : kPageSize(page_size) { data_ = new int[kPageSize]; utils::Assert(data_ != NULL, "fail to allocate row batch page"); this->Clear(); @@ -82,8 +82,6 @@ struct RowBatchPage { inline int Size(void) const { return data_[0]; } - /*! \brief page size 64 MB */ - static const size_t kPageSize = 64 << 18; private: /*! \return number of elements */ @@ -98,6 +96,8 @@ struct RowBatchPage { inline RowBatch::Entry* data_ptr(int i) { return (RowBatch::Entry*)(&data_[1]) + i; } + // page size + const size_t kPageSize; // content of data int *data_; }; @@ -137,7 +137,7 @@ class ThreadRowPageIterator: public utils::IIterator { */ inline static void Save(utils::IIterator *iter, utils::IStream &fo) { - RowBatchPage page; + RowBatchPage page(kPageSize); iter->BeforeFirst(); while (iter->Next()) { const RowBatch &batch = iter->Value(); @@ -151,6 +151,8 @@ class ThreadRowPageIterator: public utils::IIterator { } if (page.Size() != 0) page.Save(fo); } + /*! \brief page size 64 MB */ + static const size_t kPageSize = 64 << 18; private: // base row id size_t base_rowid_; @@ -178,7 +180,7 @@ class ThreadRowPageIterator: public utils::IIterator { return val->Load(fi); } inline PagePtr Create(void) { - PagePtr a = new RowBatchPage(); + PagePtr a = new RowBatchPage(kPageSize); return a; } inline void FreeSpace(PagePtr &a) { From e43bb9118541c3ad15adba958b2d7b3d5a885087 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 1 Sep 2014 21:30:03 -0700 Subject: [PATCH 12/19] add matrix builder --- src/io/page_fmatrix-inl.hpp | 7 +-- src/utils/matrix_csr.h | 112 ++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 4 deletions(-) diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index cf4923b7b..b2ce76faf 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -121,7 +121,7 @@ class CSCMatrixManager { size_t len = col_ptr_[cidx+1] - col_ptr_[cidx]; if (p_page->NumFreeEntry() < len) return false; ColBatch::Entry *p_data = p_page->AllocEntry(len); - fi_->Seek(col_ptr_[cidx]); + fi_->Seek(col_ptr_[cidx] * sizeof(ColBatch::Entry) + sizeof(size_t)); utils::Check(fi_->Read(p_data, sizeof(ColBatch::Entry) * len) != 0, "invalid column buffer format"); p_page->col_data.push_back(ColBatch::Inst(p_data, len)); @@ -285,8 +285,7 @@ class FMatrixPage : public IFMatrix { row_iter_->BeforeFirst(); while (row_iter_->Next()) { const RowBatch &batch = row_iter_->Value(); - for (size_t i = 0; i < batch.size; ++i) { - } + } row_iter_->BeforeFirst(); size_t ktop = 0; @@ -294,7 +293,7 @@ class FMatrixPage : public IFMatrix { const RowBatch &batch = row_iter_->Value(); for (size_t i = 0; i < batch.size; ++i) { if (ktop < buffered_rowset_.size() && - buffered_rowset_[ktop] == batch.base_rowid+i) { + buffered_rowset_[ktop] == batch.base_rowid + i) { ++ktop; // TODO1 } diff --git a/src/utils/matrix_csr.h b/src/utils/matrix_csr.h index 0f3b20a14..44a3b8818 100644 --- a/src/utils/matrix_csr.h +++ b/src/utils/matrix_csr.h @@ -7,6 +7,7 @@ */ #include #include +#include "./io.h" #include "./utils.h" namespace xgboost { @@ -118,6 +119,117 @@ struct SparseCSRMBuilder { } }; +/*! + * \brief a class used to help construct CSR format matrix file + * \tparam IndexType type of index used to store the index position + * \tparam SizeType type of size used in row pointer + */ +template +struct SparseCSRFileBuilder { + public: + explicit SparseCSRFileBuilder(utils::ISeekStream *fo, size_t buffer_size) + : fo(fo), buffer_size(buffer_size) { + } + /*! + * \brief step 1: initialize the number of rows in the data, not necessary exact + * \nrows number of rows in the matrix, can be smaller than expected + */ + inline void InitBudget(size_t nrows = 0) { + rptr.clear(); + rptr.resize(nrows + 1, 0); + } + /*! + * \brief step 2: add budget to each rows + * \param row_id the id of the row + * \param nelem number of element budget add to this row + */ + inline void AddBudget(size_t row_id, SizeType nelem = 1) { + if (rptr.size() < row_id + 2) { + rptr.resize(row_id + 2, 0); + } + rptr[row_id + 1] += nelem; + } + /*! \brief step 3: initialize the necessary storage */ + inline void InitStorage(void) { + SizeType nelem = 0; + for (size_t i = 1; i < rptr.size(); i++) { + nelem += rptr[i]; + rptr[i] = nelem; + } + SizeType begin_meta = sizeof(SizeType) + nelem * sizeof(IndexType); + fo->Seek(0); + fo->Write(&begin_meta, sizeof(begin_meta)); + fo->Seek(begin_meta); + fo->Write(rptr); + // setup buffer space + buffer_rptr.resize(rptr.size()); + buffer.reserve(buffer_size); + buffer_data.resize(buffer_size); + saved_offset.clear(); + saved_offset.resize(rptr.size() - 1, 0); + this->ClearBuffer(); + } + /*! \brief step 4: push element into buffer */ + inline void PushElem(SizeType row_id, IndexType col_id) { + if (buffer_temp.size() == buffer_size) { + this->WriteBuffer(); + this->ClearBuffer(); + } + buffer_temp.push_back(std::make_pair(row_id, col_id)); + } + /*! \brief finalize the construction */ + inline void Finalize(void) { + this->WriteBuffer(); + for (size_t i = 0; i < saved_offset.size(); ++i) { + utils::Assert(saved_offset[i] == rptr[i+1], "some block not write out"); + } + } + + protected: + inline void WriteBuffer(void) { + SizeType start = 0; + for (size_t i = 1; i < buffer_rptr.size(); ++i) { + size_t rlen = buffer_rptr[i]; + buffer_rptr[i] = start; + start += rlen; + } + for (size_t i = 0; i < buffer_temp.size(); ++i) { + SizeType &rp = buffer_rptr[buffer_temp[i].first + 1]; + buffer_data[rp++] = buffer_temp[i].second; + } + // write out + for (size_t i = 0; i < buffer_rptr.size(); ++i) { + size_t nelem = buffer_rptr[i+1] - buffer_rptr[i]; + if (nelem != 0) { + utils::Assert(saved_offset[i] < rptr[i+1], "data exceed bound"); + fo->Seek((rptr[i] + saved_offset[i]) * sizeof(IndexType) + sizeof(SizeType)); + fo->Write(&buffer_data[0] + buffer_rptr[i], nelem * sizeof(IndexType)); + saved_offset[i] += nelem; + } + } + } + inline void ClearBuffer(void) { + buffer_temp.clear(); + std::fill(buffer_rptr.begin(), buffer_rptr.end(), 0); + } + private: + /*! \brief output file pointer the data */ + utils::ISeekStream *fo; + /*! \brief pointer to each of the row */ + std::vector rptr; + /*! \brief saved top space of each item */ + std::vector saved_offset; + // ----- the following are buffer space + /*! \brief maximum size of content buffer*/ + size_t buffer_size; + /*! \brief store the data content */ + std::vector< std::pair > buffer_temp; + /*! \brief saved top space of each item */ + std::vector buffer_rptr; + /*! \brief saved top space of each item */ + std::vector buffer_data; +}; + } // namespace utils } // namespace xgboost #endif From 4b9aeea89c4c6fba24a8e0d487df65babb67392f Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 2 Sep 2014 13:14:54 -0700 Subject: [PATCH 13/19] finish the fmatrix --- src/io/page_dmatrix-inl.hpp | 18 +++--- src/io/page_fmatrix-inl.hpp | 108 +++++++++++++++++++++++------------- src/utils/io.h | 1 - src/utils/matrix_csr.h | 34 ++++++++++-- 4 files changed, 106 insertions(+), 55 deletions(-) diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 76767d942..83c745599 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -223,16 +223,16 @@ class DMatrixPage : public DataMatrix { this->info.LoadBinary(fi); iter_->Load(fi); if (!silent) { - printf("DMatrixPage: %lux%lu matrix is loaded", - static_cast(info.num_row()), - static_cast(info.num_col())); + utils::Printf("DMatrixPage: %lux%lu matrix is loaded", + static_cast(info.num_row()), + static_cast(info.num_col())); if (fname != NULL) { - printf(" from %s\n", fname); + utils::Printf(" from %s\n", fname); } else { - printf("\n"); + utils::Printf("\n"); } if (info.group_ptr.size() != 0) { - printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1); + utils::Printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1); } } } @@ -245,9 +245,9 @@ class DMatrixPage : public DataMatrix { ThreadRowPageIterator::Save(mat.fmat()->RowIterator(), fs); fs.Close(); if (!silent) { - printf("DMatrixPage: %lux%lu is saved to %s\n", - static_cast(mat.info.num_row()), - static_cast(mat.info.num_col()), fname); + utils::Printf("DMatrixPage: %lux%lu is saved to %s\n", + static_cast(mat.info.num_row()), + static_cast(mat.info.num_col()), fname); } } /*! \brief the real fmatrix */ diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index b2ce76faf..7e9903be4 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -7,10 +7,11 @@ */ #include "../data.h" #include "../utils/iterator.h" +#include "../utils/io.h" +#include "../utils/matrix_csr.h" #include "../utils/thread_buffer.h" namespace xgboost { namespace io { - class CSCMatrixManager { public: /*! \brief in memory page */ @@ -56,6 +57,10 @@ class CSCMatrixManager { }; /*! \brief define type of page pointer */ typedef Page *PagePtr; + // constructor + CSCMatrixManager(void) { + fi_ = NULL; + } /*! \brief get column pointer */ inline const std::vector &col_ptr(void) const { return col_ptr_; @@ -89,7 +94,8 @@ class CSCMatrixManager { } inline void Setup(utils::ISeekStream *fi, double page_ratio) { fi_ = fi; - fi_->Read(&begin_meta_ , sizeof(size_t)); + fi_->Read(&begin_meta_ , sizeof(begin_meta_)); + begin_data_ = static_cast(fi->Tell()); fi_->Seek(begin_meta_); fi_->Read(&col_ptr_); size_t psmax = 0; @@ -121,7 +127,7 @@ class CSCMatrixManager { size_t len = col_ptr_[cidx+1] - col_ptr_[cidx]; if (p_page->NumFreeEntry() < len) return false; ColBatch::Entry *p_data = p_page->AllocEntry(len); - fi_->Seek(col_ptr_[cidx] * sizeof(ColBatch::Entry) + sizeof(size_t)); + fi_->Seek(col_ptr_[cidx] * sizeof(ColBatch::Entry) + begin_data_); utils::Check(fi_->Read(p_data, sizeof(ColBatch::Entry) * len) != 0, "invalid column buffer format"); p_page->col_data.push_back(ColBatch::Inst(p_data, len)); @@ -137,6 +143,8 @@ class CSCMatrixManager { /*! \brief column index to be after calling before first */ std::vector col_todo_; // the following are input content + /*! \brief beginning position of data content */ + size_t begin_data_; /*! \brief size of data content */ size_t begin_meta_; /*! \brief input stream */ @@ -147,36 +155,25 @@ class CSCMatrixManager { class ThreadColPageIterator : public utils::IIterator { public: - ThreadColPageIterator(void) { + explicit ThreadColPageIterator(utils::ISeekStream *fi, + float page_ratio, bool silent) { itr_.SetParam("buffer_size", "2"); - page_ = NULL; - fi_ = NULL; - silent = 0; + itr_.get_factory().Setup(fi, page_ratio); + if (!silent) { + utils::Printf("ThreadColPageIterator: finish initialzing, %u columns\n", + static_cast(col_ptr().size() - 1)); + } } virtual ~ThreadColPageIterator(void) { - if (fi_ != NULL) { - fi_->Close(); delete fi_; - } - } - virtual void Init(void) { - fi_ = new utils::FileStream(utils::FopenCheck(col_pagefile_.c_str(), "rb")); - itr_.get_factory().Setup(fi_, col_pageratio_); - if (silent == 0) { - printf("ThreadColPageIterator: finish initialzing from %s, %u columns\n", - col_pagefile_.c_str(), static_cast(col_ptr().size() - 1)); - } - } - virtual void SetParam(const char *name, const char *val) { - if (!strcmp("col_pageratio", val)) col_pageratio_ = atof(val); - if (!strcmp("col_pagefile", val)) col_pagefile_ = val; - if (!strcmp("silent", val)) silent = atoi(val); } virtual void BeforeFirst(void) { itr_.BeforeFirst(); } virtual bool Next(void) { - if(!itr_.Next(page_)) return false; - out_ = page_->GetBatch(); + // page to be loaded + CSCMatrixManager::PagePtr page; + if(!itr_.Next(page)) return false; + out_ = page->GetBatch(); return true; } virtual const ColBatch &Value(void) const{ @@ -190,18 +187,8 @@ class ThreadColPageIterator : public utils::IIterator { } private: - // shutup - int silent; - // input file - utils::FileStream *fi_; - // size of page - float col_pageratio_; - // name of file - std::string col_pagefile_; // output data ColBatch out_; - // page to be loaded - CSCMatrixManager::PagePtr page_; // internal iterator utils::ThreadBuffer itr_; }; @@ -212,14 +199,18 @@ class ThreadColPageIterator : public utils::IIterator { class FMatrixPage : public IFMatrix { public: /*! \brief constructor */ - FMatrixPage(utils::IIterator *iter) { + FMatrixPage(utils::IIterator *iter, std::string fname_buffer) { this->row_iter_ = iter; this->col_iter_ = NULL; + this->fi_ = NULL; } // destructor virtual ~FMatrixPage(void) { if (row_iter_ != NULL) delete row_iter_; if (col_iter_ != NULL) delete col_iter_; + if (fi_ != NULL) { + fi_->Close(); delete fi_; + } } /*! \return whether column access is enabled */ virtual bool HaveColAccess(void) const { @@ -275,18 +266,44 @@ class FMatrixPage : public IFMatrix { } protected: + /*! + * \brief try load column data from file + */ + inline bool LoadColData(void) { + FILE *fp = fopen64(fname_cbuffer_.c_str(), "rb"); + if (fp == NULL) return false; + fi_ = new utils::FileStream(fp); + static_cast(fi_)->Read(&buffered_rowset_); + col_iter_ = new ThreadColPageIterator(fi_, 2.0f, false); + return true; + } /*! * \brief intialize column data * \param pkeep probability to keep a row */ inline void InitColData(float pkeep) { - buffered_rowset_.clear(); + buffered_rowset_.clear(); + utils::FileStream fo(utils::FopenCheck(fname_cbuffer_.c_str(), "wb+")); + // use 64M buffer + utils::SparseCSRFileBuilder builder(&fo, 64<<20); + // start working row_iter_->BeforeFirst(); while (row_iter_->Next()) { const RowBatch &batch = row_iter_->Value(); - + for (size_t i = 0; i < batch.size; ++i) { + if (pkeep == 1.0f || random::SampleBinary(pkeep)) { + buffered_rowset_.push_back(static_cast(batch.base_rowid+i)); + RowBatch::Inst inst = batch[i]; + for (bst_uint j = 0; j < inst.length; ++j) { + builder.AddBudget(inst[j].index); + } + } + } } + // write buffered rowset + static_cast(&fo)->Write(buffered_rowset_); + builder.InitStorage(); row_iter_->BeforeFirst(); size_t ktop = 0; while (row_iter_->Next()) { @@ -295,11 +312,18 @@ class FMatrixPage : public IFMatrix { if (ktop < buffered_rowset_.size() && buffered_rowset_[ktop] == batch.base_rowid + i) { ++ktop; - // TODO1 + RowBatch::Inst inst = batch[i]; + for (bst_uint j = 0; j < inst.length; ++j) { + builder.PushElem(inst[j].index, + ColBatch::Entry((bst_uint)(batch.base_rowid+i), + inst[j].fvalue)); + } } } } - // sort columns + builder.Finalize(); + builder.SortRows(ColBatch::Entry::CmpValue, 5); + fo.Close(); } private: @@ -307,6 +331,10 @@ class FMatrixPage : public IFMatrix { utils::IIterator *row_iter_; // column iterator ThreadColPageIterator *col_iter_; + // file pointer to data + utils::FileStream *fi_; + // file name of column buffer + std::string fname_cbuffer_; /*! \brief list of row index that are buffered */ std::vector buffered_rowset_; }; diff --git a/src/utils/io.h b/src/utils/io.h index d98b3e4dc..37f489955 100644 --- a/src/utils/io.h +++ b/src/utils/io.h @@ -125,7 +125,6 @@ class FileStream : public ISeekStream { private: FILE *fp; }; - } // namespace utils } // namespace xgboost #endif diff --git a/src/utils/matrix_csr.h b/src/utils/matrix_csr.h index b2768b2ea..e4c410511 100644 --- a/src/utils/matrix_csr.h +++ b/src/utils/matrix_csr.h @@ -9,6 +9,7 @@ #include #include "./io.h" #include "./utils.h" +#include "./omp.h" namespace xgboost { namespace utils { @@ -155,9 +156,9 @@ struct SparseCSRFileBuilder { for (size_t i = 1; i < rptr.size(); i++) { nelem += rptr[i]; rptr[i] = nelem; - } - SizeType begin_meta = sizeof(SizeType) + nelem * sizeof(IndexType); - fo->Seek(0); + } + begin_data = static_cast(fo->Tell()) + sizeof(SizeType); + SizeType begin_meta = begin_data + nelem * sizeof(IndexType); fo->Write(&begin_meta, sizeof(begin_meta)); fo->Seek(begin_meta); fo->Write(rptr); @@ -184,7 +185,28 @@ struct SparseCSRFileBuilder { utils::Assert(saved_offset[i] == rptr[i+1], "some block not write out"); } } - + /*! \brief content must be in wb+ */ + template + inline void SortRows(Comparator comp, size_t step) { + for (size_t i = 0; i < rptr.size() - 1; i += step) { + bst_omp_uint begin = static_cast(i); + bst_omp_uint end = static_cast(std::min(rptr.size(), i + step)); + if (rptr[end] != rptr[begin]) { + fo->Seek(begin_data + rptr[begin] * sizeof(IndexType)); + buffer_data.resize(rptr[end] - rptr[begin]); + fo->Read(BeginPtr(buffer_data), (rptr[end] - rptr[begin]) * sizeof(IndexType)); + // do parallel sorting + #pragma omp parallel for schedule(static) + for (bst_omp_uint j = begin; j < end; ++j){ + std::sort(&buffer_data[0] + rptr[j] - rptr[begin], + &buffer_data[0] + rptr[j+1] - rptr[begin], + comp); + } + fo->Seek(begin_data + rptr[begin] * sizeof(IndexType)); + fo->Write(BeginPtr(buffer_data), (rptr[end] - rptr[begin]) * sizeof(IndexType)); + } + } + } protected: inline void WriteBuffer(void) { SizeType start = 0; @@ -202,7 +224,7 @@ struct SparseCSRFileBuilder { size_t nelem = buffer_rptr[i+1] - buffer_rptr[i]; if (nelem != 0) { utils::Assert(saved_offset[i] < rptr[i+1], "data exceed bound"); - fo->Seek((rptr[i] + saved_offset[i]) * sizeof(IndexType) + sizeof(SizeType)); + fo->Seek((rptr[i] + saved_offset[i]) * sizeof(IndexType) + begin_data); fo->Write(&buffer_data[0] + buffer_rptr[i], nelem * sizeof(IndexType)); saved_offset[i] += nelem; } @@ -219,6 +241,8 @@ struct SparseCSRFileBuilder { std::vector rptr; /*! \brief saved top space of each item */ std::vector saved_offset; + /*! \brief beginning position of data */ + size_t begin_data; // ----- the following are buffer space /*! \brief maximum size of content buffer*/ size_t buffer_size; From a89e3063e6cc327f1552cdea154864fa510f8040 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 2 Sep 2014 15:34:11 -0700 Subject: [PATCH 14/19] untested version of cpage --- src/io/io.cpp | 11 ++++++++++ src/io/page_dmatrix-inl.hpp | 44 +++++++++++++++++++++++-------------- src/io/page_fmatrix-inl.hpp | 32 ++++++++++++++++++++++----- 3 files changed, 65 insertions(+), 22 deletions(-) diff --git a/src/io/io.cpp b/src/io/io.cpp index c2d9e26d3..faed31f13 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -7,6 +7,7 @@ using namespace std; #include "../utils/utils.h" #include "simple_dmatrix-inl.hpp" #include "page_dmatrix-inl.hpp" +#include "page_fmatrix-inl.hpp" // implements data loads using dmatrix simple for now @@ -30,6 +31,12 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { // the file pointer is hold in page matrix return dmat; } + if (magic == DMatrixColPage::kMagic) { + DMatrixColPage *dmat = new DMatrixColPage(fname); + dmat->Load(fs, silent, fname); + // the file pointer is hold in page matrix + return dmat; + } fs.Close(); DMatrixSimple *dmat = new DMatrixSimple(); @@ -42,6 +49,10 @@ void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) { DMatrixPage::Save(fname, dmat, silent); return; } + if (!strcmp(fname + strlen(fname) - 6, ".cpage")) { + DMatrixColPage::Save(fname, dmat, silent); + return; + } if (dmat.magic == DMatrixSimple::kMagic) { const DMatrixSimple *p_dmat = static_cast(&dmat); p_dmat->SaveBinary(fname, silent); diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 83c745599..63010d882 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -9,7 +9,6 @@ #include "../utils/iterator.h" #include "../utils/thread_buffer.h" #include "./simple_fmatrix-inl.hpp" -#include "./page_fmatrix-inl.hpp" namespace xgboost { namespace io { @@ -200,26 +199,24 @@ class ThreadRowPageIterator: public utils::IIterator { }; /*! \brief data matrix using page */ -class DMatrixPage : public DataMatrix { +template +class DMatrixPageBase : public DataMatrix { public: - DMatrixPage(void) : DataMatrix(kMagic) { + DMatrixPageBase(void) : DataMatrix(kMagic) { iter_ = new ThreadRowPageIterator(); - fmat_ = new FMatrixS(iter_); } // virtual destructor - virtual ~DMatrixPage(void) { - delete fmat_; - } - virtual IFMatrix *fmat(void) const { - return fmat_; + virtual ~DMatrixPageBase(void) { + // do not delete row iterator, since it is owned by fmat + // to be cleaned up in a more clear way } /*! \brief load and initialize the iterator with fi */ inline void Load(utils::FileStream &fi, bool silent = false, const char *fname = NULL){ - int magic; - utils::Check(fi.Read(&magic, sizeof(magic)) != 0, "invalid input file format"); - utils::Check(magic == kMagic, "invalid format,magic number mismatch"); + int tmagic; + utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format"); + utils::Check(tmagic == magic, "invalid format,magic number mismatch"); this->info.LoadBinary(fi); iter_->Load(fi); if (!silent) { @@ -250,12 +247,27 @@ class DMatrixPage : public DataMatrix { static_cast(mat.info.num_col()), fname); } } - /*! \brief the real fmatrix */ - FMatrixS *fmat_; + /*! \brief magic number used to identify DMatrix */ + static const int kMagic = TKMagic; + protected: + /*! \brief row iterator */ ThreadRowPageIterator *iter_; - /*! \brief magic number used to identify DMatrix */ - static const int kMagic = 0xffffab02; +}; + +class DMatrixPage : public DMatrixPageBase<0xffffab02> { + public: + DMatrixPage(void) { + fmat_ = new FMatrixS(iter_); + } + virtual ~DMatrixPage(void) { + delete fmat_; + } + virtual IFMatrix *fmat(void) const { + return fmat_; + } + /*! \brief the real fmatrix */ + IFMatrix *fmat_; }; } // namespace io } // namespace xgboost diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index 7e9903be4..4189c0c85 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -199,7 +199,8 @@ class ThreadColPageIterator : public utils::IIterator { class FMatrixPage : public IFMatrix { public: /*! \brief constructor */ - FMatrixPage(utils::IIterator *iter, std::string fname_buffer) { + FMatrixPage(utils::IIterator *iter, std::string fname_buffer) + : fname_cbuffer_(fname_buffer) { this->row_iter_ = iter; this->col_iter_ = NULL; this->fi_ = NULL; @@ -238,7 +239,8 @@ class FMatrixPage : public IFMatrix { } virtual void InitColAccess(float pkeep = 1.0f) { if (this->HaveColAccess()) return; - this->InitColData(pkeep); + this->InitColData(pkeep, fname_cbuffer_.c_str(), + 64 << 20, 5); } /*! * \brief get the row iterator associated with FMatrix @@ -281,11 +283,12 @@ class FMatrixPage : public IFMatrix { * \brief intialize column data * \param pkeep probability to keep a row */ - inline void InitColData(float pkeep) { + inline void InitColData(float pkeep, const char *fname, + size_t buffer_size, size_t col_step) { buffered_rowset_.clear(); - utils::FileStream fo(utils::FopenCheck(fname_cbuffer_.c_str(), "wb+")); + utils::FileStream fo(utils::FopenCheck(fname, "wb+")); // use 64M buffer - utils::SparseCSRFileBuilder builder(&fo, 64<<20); + utils::SparseCSRFileBuilder builder(&fo, buffer_size); // start working row_iter_->BeforeFirst(); @@ -322,7 +325,7 @@ class FMatrixPage : public IFMatrix { } } builder.Finalize(); - builder.SortRows(ColBatch::Entry::CmpValue, 5); + builder.SortRows(ColBatch::Entry::CmpValue, col_step); fo.Close(); } @@ -339,6 +342,23 @@ class FMatrixPage : public IFMatrix { std::vector buffered_rowset_; }; +class DMatrixColPage : public DMatrixPageBase<0xffffab03> { + public: + DMatrixColPage(const char *fname) { + std::string fext = fname; + fext += ".col"; + fmat_ = new FMatrixPage(iter_, fext.c_str()); + } + virtual ~DMatrixColPage(void) { + delete fmat_; + } + virtual IFMatrix *fmat(void) const { + return fmat_; + } + /*! \brief the real fmatrix */ + IFMatrix *fmat_; +}; + } // namespace io } // namespace xgboost #endif // XGBOOST_IO_PAGE_FMATRIX_INL_HPP_ From 226d26d40c7a7c44e607dab2a7ae476b3e15fd58 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 2 Sep 2014 17:18:17 -0700 Subject: [PATCH 15/19] still buggy --- src/io/page_fmatrix-inl.hpp | 9 +++++++-- src/utils/matrix_csr.h | 21 ++++++++++++--------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index 4189c0c85..9e586e1c4 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -132,6 +132,7 @@ class CSCMatrixManager { "invalid column buffer format"); p_page->col_data.push_back(ColBatch::Inst(p_data, len)); p_page->col_index.push_back(cidx); + return true; } // the following are in memory auxiliary data structure /*! \brief top of reader position */ @@ -159,6 +160,7 @@ class ThreadColPageIterator : public utils::IIterator { float page_ratio, bool silent) { itr_.SetParam("buffer_size", "2"); itr_.get_factory().Setup(fi, page_ratio); + itr_.Init(); if (!silent) { utils::Printf("ThreadColPageIterator: finish initialzing, %u columns\n", static_cast(col_ptr().size() - 1)); @@ -239,8 +241,11 @@ class FMatrixPage : public IFMatrix { } virtual void InitColAccess(float pkeep = 1.0f) { if (this->HaveColAccess()) return; - this->InitColData(pkeep, fname_cbuffer_.c_str(), - 64 << 20, 5); + if (!this->LoadColData()) { + this->InitColData(pkeep, fname_cbuffer_.c_str(), + 64 << 20, 5); + utils::Check(this->LoadColData(), "fail to read in column data"); + } } /*! * \brief get the row iterator associated with FMatrix diff --git a/src/utils/matrix_csr.h b/src/utils/matrix_csr.h index e4c410511..ea5bc8b2d 100644 --- a/src/utils/matrix_csr.h +++ b/src/utils/matrix_csr.h @@ -6,6 +6,7 @@ * \author Tianqi Chen */ #include +#include #include #include "./io.h" #include "./utils.h" @@ -156,7 +157,7 @@ struct SparseCSRFileBuilder { for (size_t i = 1; i < rptr.size(); i++) { nelem += rptr[i]; rptr[i] = nelem; - } + } begin_data = static_cast(fo->Tell()) + sizeof(SizeType); SizeType begin_meta = begin_data + nelem * sizeof(IndexType); fo->Write(&begin_meta, sizeof(begin_meta)); @@ -166,8 +167,8 @@ struct SparseCSRFileBuilder { buffer_rptr.resize(rptr.size()); buffer_temp.reserve(buffer_size); buffer_data.resize(buffer_size); - saved_offset.clear(); - saved_offset.resize(rptr.size() - 1, 0); + saved_offset = rptr; + saved_offset.resize(rptr.size() - 1); this->ClearBuffer(); } /*! \brief step 4: push element into buffer */ @@ -176,7 +177,8 @@ struct SparseCSRFileBuilder { this->WriteBuffer(); this->ClearBuffer(); } - buffer_temp.push_back(std::make_pair(row_id, col_id)); + buffer_rptr[row_id + 1] += 1; + buffer_temp.push_back(std::make_pair(row_id, col_id)); } /*! \brief finalize the construction */ inline void Finalize(void) { @@ -190,14 +192,14 @@ struct SparseCSRFileBuilder { inline void SortRows(Comparator comp, size_t step) { for (size_t i = 0; i < rptr.size() - 1; i += step) { bst_omp_uint begin = static_cast(i); - bst_omp_uint end = static_cast(std::min(rptr.size(), i + step)); + bst_omp_uint end = static_cast(std::min(rptr.size() - 1, i + step)); if (rptr[end] != rptr[begin]) { fo->Seek(begin_data + rptr[begin] * sizeof(IndexType)); buffer_data.resize(rptr[end] - rptr[begin]); fo->Read(BeginPtr(buffer_data), (rptr[end] - rptr[begin]) * sizeof(IndexType)); // do parallel sorting #pragma omp parallel for schedule(static) - for (bst_omp_uint j = begin; j < end; ++j){ + for (bst_omp_uint j = begin; j < end; ++j) { std::sort(&buffer_data[0] + rptr[j] - rptr[begin], &buffer_data[0] + rptr[j+1] - rptr[begin], comp); @@ -206,6 +208,7 @@ struct SparseCSRFileBuilder { fo->Write(BeginPtr(buffer_data), (rptr[end] - rptr[begin]) * sizeof(IndexType)); } } + printf("CSV::begin_dat=%lu\n", begin_data); } protected: inline void WriteBuffer(void) { @@ -220,11 +223,11 @@ struct SparseCSRFileBuilder { buffer_data[rp++] = buffer_temp[i].second; } // write out - for (size_t i = 0; i < buffer_rptr.size(); ++i) { + for (size_t i = 0; i < buffer_rptr.size() - 1; ++i) { size_t nelem = buffer_rptr[i+1] - buffer_rptr[i]; if (nelem != 0) { - utils::Assert(saved_offset[i] < rptr[i+1], "data exceed bound"); - fo->Seek((rptr[i] + saved_offset[i]) * sizeof(IndexType) + begin_data); + utils::Assert(saved_offset[i] + nelem <= rptr[i+1], "data exceed bound"); + fo->Seek(saved_offset[i] * sizeof(IndexType) + begin_data); fo->Write(&buffer_data[0] + buffer_rptr[i], nelem * sizeof(IndexType)); saved_offset[i] += nelem; } From f3360d173b1eaf2a703c1cf5b7345513734f3fa7 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 2 Sep 2014 17:38:51 -0700 Subject: [PATCH 16/19] pass trival test --- src/io/page_fmatrix-inl.hpp | 21 +++++++++++---------- src/tree/updater_colmaker-inl.hpp | 3 ++- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index 9e586e1c4..22766ab65 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -46,9 +46,10 @@ class CSCMatrixManager { } /*! \brief get underlying batch */ inline ColBatch GetBatch(void) const { - ColBatch batch; - batch.col_index = &col_index[0]; - batch.col_data = &col_data[0]; + ColBatch batch; + batch.size = col_index.size(); + batch.col_index = BeginPtr(col_index); + batch.col_data = BeginPtr(col_data); return batch; } private: @@ -79,11 +80,13 @@ class CSCMatrixManager { col_index_ = col_todo_; read_top_ = 0; } - inline bool LoadNext(PagePtr &val) { + inline bool LoadNext(PagePtr &val) { val->Clear(); if (read_top_ >= col_index_.size()) return false; while (read_top_ < col_index_.size()) { - if (!this->TryFill(col_index_[read_top_], val)) return true; + if (!this->TryFill(col_index_[read_top_], val)) { + return true; + } ++read_top_; } return true; @@ -241,11 +244,9 @@ class FMatrixPage : public IFMatrix { } virtual void InitColAccess(float pkeep = 1.0f) { if (this->HaveColAccess()) return; - if (!this->LoadColData()) { - this->InitColData(pkeep, fname_cbuffer_.c_str(), - 64 << 20, 5); - utils::Check(this->LoadColData(), "fail to read in column data"); - } + this->InitColData(pkeep, fname_cbuffer_.c_str(), + 64 << 20, 5); + utils::Check(this->LoadColData(), "fail to read in column data"); } /*! * \brief get the row iterator associated with FMatrix diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp index bf93cb7b5..12f808ce4 100644 --- a/src/tree/updater_colmaker-inl.hpp +++ b/src/tree/updater_colmaker-inl.hpp @@ -421,7 +421,7 @@ class ColMaker: public IUpdater { for (bst_omp_uint i = 0; i < nsize; ++i) { const bst_uint fid = batch.col_index[i]; const int tid = omp_get_thread_num(); - const ColBatch::Inst c = batch[i]; + const ColBatch::Inst c = batch[i]; if (param.need_forward_search(fmat.GetColDensity(fid))) { this->EnumerateSplit(c.data, c.data + c.length, +1, fid, gpair, info, stemp[tid]); @@ -452,6 +452,7 @@ class ColMaker: public IUpdater { utils::Check(n > 0, "colsample_bylevel is too small that no feature can be included"); feat_set.resize(n); } + std::sort(feat_set.begin(), feat_set.end()); utils::IIterator *iter = p_fmat->ColIterator(feat_set); while (iter->Next()) { this->UpdateSolution(iter->Value(), gpair, *p_fmat, info); From e6e467ad6093533c6130746e1755346ef1b4fbb8 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 2 Sep 2014 17:40:30 -0700 Subject: [PATCH 17/19] more ignore --- .gitignore | 2 ++ src/tree/updater_colmaker-inl.hpp | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 220fc602a..1a2a4b48e 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,5 @@ Debug *dump *save *csv +*.cpage.col +*.cpage diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp index 12f808ce4..566e57752 100644 --- a/src/tree/updater_colmaker-inl.hpp +++ b/src/tree/updater_colmaker-inl.hpp @@ -452,7 +452,6 @@ class ColMaker: public IUpdater { utils::Check(n > 0, "colsample_bylevel is too small that no feature can be included"); feat_set.resize(n); } - std::sort(feat_set.begin(), feat_set.end()); utils::IIterator *iter = p_fmat->ColIterator(feat_set); while (iter->Next()) { this->UpdateSolution(iter->Value(), gpair, *p_fmat, info); From 401d648372bcae81346e4ca5226a640abf073e41 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 2 Sep 2014 17:49:39 -0700 Subject: [PATCH 18/19] some lint --- src/io/page_dmatrix-inl.hpp | 37 +++++++++++++++++++------------------ src/io/page_fmatrix-inl.hpp | 36 ++++++++++++++++++++---------------- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 63010d882..7a0781621 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -5,6 +5,7 @@ * row iterator based on sparse page * \author Tianqi Chen */ +#include #include "../data.h" #include "../utils/iterator.h" #include "../utils/thread_buffer.h" @@ -15,7 +16,7 @@ namespace io { /*! \brief page structure that can be used to store a rowbatch */ struct RowBatchPage { public: - RowBatchPage(size_t page_size) : kPageSize(page_size) { + explicit RowBatchPage(size_t page_size) : kPageSize(page_size) { data_ = new int[kPageSize]; utils::Assert(data_ != NULL, "fail to allocate row batch page"); this->Clear(); @@ -31,10 +32,10 @@ struct RowBatchPage { inline bool PushRow(const RowBatch::Inst &row) { const size_t dsize = row.length * sizeof(RowBatch::Entry); if (FreeBytes() < dsize+ sizeof(int)) return false; - row_ptr(Size() + 1) = row_ptr(Size()) + row.length; + row_ptr(Size() + 1) = row_ptr(Size()) + row.length; memcpy(data_ptr(row_ptr(Size())) , row.data, dsize); - ++ data_[0]; - return true; + ++data_[0]; + return true; } /*! * \brief get a row batch representation from the page @@ -43,7 +44,7 @@ struct RowBatchPage { * \return a new RowBatch object */ inline RowBatch GetRowBatch(std::vector *p_rptr, size_t base_rowid) { - RowBatch batch; + RowBatch batch; batch.base_rowid = base_rowid; batch.data_ptr = this->data_ptr(0); batch.size = static_cast(this->Size()); @@ -57,7 +58,7 @@ struct RowBatchPage { } /*! \brief get i-th row from the batch */ inline RowBatch::Inst operator[](int i) { - return RowBatch::Inst(data_ptr(0) + row_ptr(i), + return RowBatch::Inst(data_ptr(0) + row_ptr(i), static_cast(row_ptr(i+1) - row_ptr(i))); } /*! @@ -85,8 +86,8 @@ struct RowBatchPage { private: /*! \return number of elements */ inline size_t FreeBytes(void) { - return (kPageSize - (Size() + 2)) * sizeof(int) - - row_ptr(Size()) * sizeof(RowBatch::Entry) ; + return (kPageSize - (Size() + 2)) * sizeof(int) - + row_ptr(Size()) * sizeof(RowBatch::Entry); } /*! \brief equivalent row pointer at i */ inline int& row_ptr(int i) { @@ -98,7 +99,7 @@ struct RowBatchPage { // page size const size_t kPageSize; // content of data - int *data_; + int *data_; }; /*! \brief thread buffer iterator */ class ThreadRowPageIterator: public utils::IIterator { @@ -108,8 +109,7 @@ class ThreadRowPageIterator: public utils::IIterator { page_ = NULL; base_rowid_ = 0; } - virtual ~ThreadRowPageIterator(void) { - } + virtual ~ThreadRowPageIterator(void) {} virtual void Init(void) { } virtual void BeforeFirst(void) { @@ -117,12 +117,12 @@ class ThreadRowPageIterator: public utils::IIterator { base_rowid_ = 0; } virtual bool Next(void) { - if(!itr.Next(page_)) return false; + if (!itr.Next(page_)) return false; out_ = page_->GetRowBatch(&tmp_ptr_, base_rowid_); base_rowid_ += out_.size; return true; } - virtual const RowBatch &Value(void) const{ + virtual const RowBatch &Value(void) const { return out_; } /*! \brief load and initialize the iterator with fi */ @@ -152,6 +152,7 @@ class ThreadRowPageIterator: public utils::IIterator { } /*! \brief page size 64 MB */ static const size_t kPageSize = 64 << 18; + private: // base row id size_t base_rowid_; @@ -195,7 +196,7 @@ class ThreadRowPageIterator: public utils::IIterator { protected: PagePtr page_; - utils::ThreadBuffer itr; + utils::ThreadBuffer itr; }; /*! \brief data matrix using page */ @@ -213,10 +214,10 @@ class DMatrixPageBase : public DataMatrix { /*! \brief load and initialize the iterator with fi */ inline void Load(utils::FileStream &fi, bool silent = false, - const char *fname = NULL){ + const char *fname = NULL) { int tmagic; utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format"); - utils::Check(tmagic == magic, "invalid format,magic number mismatch"); + utils::Check(tmagic == magic, "invalid format,magic number mismatch"); this->info.LoadBinary(fi); iter_->Load(fi); if (!silent) { @@ -229,7 +230,7 @@ class DMatrixPageBase : public DataMatrix { utils::Printf("\n"); } if (info.group_ptr.size() != 0) { - utils::Printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1); + utils::Printf("data contains %u groups\n", (unsigned)info.group_ptr.size() - 1); } } } @@ -249,8 +250,8 @@ class DMatrixPageBase : public DataMatrix { } /*! \brief magic number used to identify DMatrix */ static const int kMagic = TKMagic; - protected: + protected: /*! \brief row iterator */ ThreadRowPageIterator *iter_; }; diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index 22766ab65..327e5c144 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -5,6 +5,9 @@ * sparse page manager for fmatrix * \author Tianqi Chen */ +#include +#include +#include #include "../data.h" #include "../utils/iterator.h" #include "../utils/io.h" @@ -34,7 +37,7 @@ class CSCMatrixManager { /*! \brief column index */ std::vector col_index; /*! \brief column data */ - std::vector col_data; + std::vector col_data; /*! \brief number of free entries */ inline size_t NumFreeEntry(void) const { return buffer.size() - num_entry; @@ -52,6 +55,7 @@ class CSCMatrixManager { batch.col_data = BeginPtr(col_data); return batch; } + private: /*! \brief buffer space, not to be changed since ready */ std::vector buffer; @@ -80,7 +84,7 @@ class CSCMatrixManager { col_index_ = col_todo_; read_top_ = 0; } - inline bool LoadNext(PagePtr &val) { + inline bool LoadNext(PagePtr &val) { val->Clear(); if (read_top_ >= col_index_.size()) return false; while (read_top_ < col_index_.size()) { @@ -106,7 +110,7 @@ class CSCMatrixManager { psmax = std::max(psmax, col_ptr_[i+1] - col_ptr_[i]); } utils::Check(page_ratio >= 1.0f, "col_page_ratio must be at least 1"); - page_size_ = std::max(static_cast(psmax * page_ratio), psmax); + page_size_ = std::max(static_cast(psmax * page_ratio), psmax); } inline void SetColSet(const std::vector &cset, bool setall) { if (!setall) { @@ -124,6 +128,7 @@ class CSCMatrixManager { } } } + private: /*! \brief fill a page with */ inline bool TryFill(size_t cidx, Page *p_page) { @@ -173,21 +178,22 @@ class ThreadColPageIterator : public utils::IIterator { } virtual void BeforeFirst(void) { itr_.BeforeFirst(); - } + } virtual bool Next(void) { // page to be loaded CSCMatrixManager::PagePtr page; - if(!itr_.Next(page)) return false; + if (!itr_.Next(page)) return false; out_ = page->GetBatch(); return true; } - virtual const ColBatch &Value(void) const{ + virtual const ColBatch &Value(void) const { return out_; } inline const std::vector &col_ptr(void) const { return itr_.get_factory().col_ptr(); } - inline void SetColSet(const std::vector &cset, bool setall = false) { + inline void SetColSet(const std::vector &cset, + bool setall = false) { itr_.get_factory().SetColSet(cset, setall); } @@ -195,9 +201,8 @@ class ThreadColPageIterator : public utils::IIterator { // output data ColBatch out_; // internal iterator - utils::ThreadBuffer itr_; + utils::ThreadBuffer itr_; }; - /*! * \brief sparse matrix that support column access */ @@ -216,7 +221,7 @@ class FMatrixPage : public IFMatrix { if (col_iter_ != NULL) delete col_iter_; if (fi_ != NULL) { fi_->Close(); delete fi_; - } + } } /*! \return whether column access is enabled */ virtual bool HaveColAccess(void) const { @@ -272,7 +277,7 @@ class FMatrixPage : public IFMatrix { col_iter_->BeforeFirst(); return col_iter_; } - + protected: /*! * \brief try load column data from file @@ -282,25 +287,24 @@ class FMatrixPage : public IFMatrix { if (fp == NULL) return false; fi_ = new utils::FileStream(fp); static_cast(fi_)->Read(&buffered_rowset_); - col_iter_ = new ThreadColPageIterator(fi_, 2.0f, false); + col_iter_ = new ThreadColPageIterator(fi_, 2.0f, false); return true; } /*! * \brief intialize column data * \param pkeep probability to keep a row */ - inline void InitColData(float pkeep, const char *fname, + inline void InitColData(float pkeep, const char *fname, size_t buffer_size, size_t col_step) { buffered_rowset_.clear(); utils::FileStream fo(utils::FopenCheck(fname, "wb+")); // use 64M buffer utils::SparseCSRFileBuilder builder(&fo, buffer_size); - // start working row_iter_->BeforeFirst(); while (row_iter_->Next()) { const RowBatch &batch = row_iter_->Value(); - for (size_t i = 0; i < batch.size; ++i) { + for (size_t i = 0; i < batch.size; ++i) { if (pkeep == 1.0f || random::SampleBinary(pkeep)) { buffered_rowset_.push_back(static_cast(batch.base_rowid+i)); RowBatch::Inst inst = batch[i]; @@ -350,7 +354,7 @@ class FMatrixPage : public IFMatrix { class DMatrixColPage : public DMatrixPageBase<0xffffab03> { public: - DMatrixColPage(const char *fname) { + explicit DMatrixColPage(const char *fname) { std::string fext = fname; fext += ".col"; fmat_ = new FMatrixPage(iter_, fext.c_str()); From 244a589e5de7f1df7e4418a5fec28a3604217353 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 3 Sep 2014 11:31:05 -0700 Subject: [PATCH 19/19] change include order, so that Rinternal does not disturb us --- R-package/src/xgboost_R.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R-package/src/xgboost_R.cpp b/R-package/src/xgboost_R.cpp index a7753dfa5..9171948eb 100644 --- a/R-package/src/xgboost_R.cpp +++ b/R-package/src/xgboost_R.cpp @@ -3,12 +3,13 @@ #include #include #include -#include "xgboost_R.h" #include "wrapper/xgboost_wrapper.h" #include "src/utils/utils.h" #include "src/utils/omp.h" #include "src/utils/matrix_csr.h" -using namespace std; + +#include "xgboost_R.h" + using namespace xgboost; extern "C" {