diff --git a/src/io/io.cpp b/src/io/io.cpp index d251d7a96..f56cff679 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -5,6 +5,8 @@ #include "../utils/io.h" #include "../utils/utils.h" #include "simple_dmatrix-inl.hpp" +#include "page_dmatrix-inl.hpp" + // implements data loads using dmatrix simple for now namespace xgboost { diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp new file mode 100644 index 000000000..82a373352 --- /dev/null +++ b/src/io/page_dmatrix-inl.hpp @@ -0,0 +1,204 @@ +#ifndef XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_ +#define XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_ +/*! + * \file page_row_iter-inl.hpp + * row iterator based on sparse page + * \author Tianqi Chen + */ +#include "../data.h" +#include "../utils/iterator.h" +#include "../utils/thread_buffer.h" +namespace xgboost { +namespace io { +/*! \brief page structure that can be used to store a rowbatch */ +struct RowBatchPage { + public: + RowBatchPage(void) { + data_ = new int[kPageSize]; + utils::Assert(data_ != NULL, "fail to allocate row batch page"); + this->Clear(); + } + ~RowBatchPage(void) { + if (data_ != NULL) delete [] data_; + } + /*! + * \brief Push one row into page + * \param row an instance row + * \return false or true to push into + */ + inline bool PushRow(const RowBatch::Inst &row) { + const size_t dsize = row.length * sizeof(RowBatch::Entry); + if (FreeBytes() < dsize+ sizeof(int)) return false; + row_ptr(Size() + 1) = row_ptr(Size()) + row.length; + memcpy(data_ptr(Size()) , row.data, dsize); + ++ data_[0]; + return true; + } + /*! 
+ * \brief get a row batch representation from the page + * \param p_rptr a temporal space that can be used to provide + * ind_ptr storage for RowBatch + * \return a new RowBatch object + */ + inline RowBatch GetRowBatch(std::vector *p_rptr, size_t base_rowid) { + RowBatch batch; + batch.base_rowid = base_rowid; + batch.data_ptr = this->data_ptr(0); + batch.size = static_cast(this->Size()); + std::vector &rptr = *p_rptr; + rptr.resize(this->Size()+1); + for (size_t i = 0; i < rptr.size(); ++i) { + rptr[i] = static_cast(this->row_ptr(i)); + } + batch.ind_ptr = &rptr[0]; + return batch; + } + /*! + * \brief clear the page, cleanup the content + */ + inline void Clear(void) { + memset(&data_[0], 0, sizeof(int) * kPageSize); + } + /*! + * \brief load one page form instream + * \return true if loading is successful + */ + inline bool Load(utils::IStream &fi) { + return fi.Read(&data_[0], sizeof(int) * kPageSize) != 0; + } + /*! \brief save one page into outstream */ + inline void Save(utils::IStream &fo) { + fo.Write(&data_[0], sizeof(int) * kPageSize); + } + /*! \return number of elements */ + inline int Size(void) const { + return data_[0]; + } + /*! \brief page size 64 MB */ + static const size_t kPageSize = 64 << 18; + + private: + /*! \return number of elements */ + inline size_t FreeBytes(void) { + return (kPageSize - (Size() + 2)) * sizeof(int) + - row_ptr(Size()) * sizeof(RowBatch::Entry) ; + } + /*! \brief equivalent row pointer at i */ + inline int& row_ptr(int i) { + return data_[kPageSize - i - 1]; + } + inline RowBatch::Entry* data_ptr(int i) { + return (RowBatch::Entry*)(&data_[1]) + i; + } + // content of data + int *data_; +}; +/*! 
\brief thread buffer iterator */ +class ThreadRowPageIterator: public utils::IIterator { + public: + ThreadRowPageIterator(void) { + itr.SetParam("buffer_size", "4"); + page_ = NULL; + base_rowid_ = 0; + isend_ = false; + } + virtual ~ThreadRowPageIterator(void) { + } + virtual void Init(void) { + } + virtual void BeforeFirst(void) { + itr.BeforeFirst(); + isend_ = false; + base_rowid_ = 0; + utils::Assert(this->LoadNextPage(), "ThreadRowPageIterator"); + } + virtual bool Next(void) { + if(!this->LoadNextPage()) return false; + out_ = page_->GetRowBatch(&tmp_ptr_, base_rowid_); + base_rowid_ += out_.size; + return true; + } + virtual const RowBatch &Value(void) const{ + return out_; + } + /*! \brief load and initialize the iterator with fi */ + inline void Load(const utils::FileStream &fi) { + itr.get_factory().SetFile(fi); + itr.Init(); + this->BeforeFirst(); + } + /*! + * \brief save a row iterator to output stream, in row iterator format + */ + inline static void Save(utils::IIterator *iter, + utils::IStream &fo) { + RowBatchPage page; + iter->BeforeFirst(); + while (iter->Next()) { + const RowBatch &batch = iter->Value(); + for (size_t i = 0; i < batch.size; ++i) { + if (!page.PushRow(batch[i])) { + page.Save(fo); + page.Clear(); + utils::Check(page.PushRow(batch[i]), "row is too big"); + } + } + } + if (page.Size() != 0) page.Save(fo); + } + private: + // load in next page + inline bool LoadNextPage(void) { + ptop_ = 0; + bool ret = itr.Next(page_); + isend_ = !ret; + return ret; + } + // base row id + size_t base_rowid_; + // temporal ptr + std::vector tmp_ptr_; + // output data + RowBatch out_; + // whether we reach end of file + bool isend_; + // page pointer type + typedef RowBatchPage* PagePtr; + // loader factory for page + struct Factory { + public: + size_t file_begin_; + utils::FileStream fi; + Factory(void) {} + inline void SetFile(const utils::FileStream &fi) { + this->fi = fi; + file_begin_ = this->fi.Tell(); + } + inline bool Init(void) { + return 
true; + } + inline void SetParam(const char *name, const char *val) {} + inline bool LoadNext(PagePtr &val) { + return val->Load(fi); + } + inline PagePtr Create(void) { + PagePtr a = new RowBatchPage(); + return a; + } + inline void FreeSpace(PagePtr &a) { + delete a; + } + inline void Destroy(void) {} + inline void BeforeFirst(void) { + fi.Seek(file_begin_); + } + }; + + protected: + PagePtr page_; + int ptop_; + utils::ThreadBuffer itr; +}; +} // namespace io +} // namespace xgboost +#endif // XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_ diff --git a/src/utils/io.h b/src/utils/io.h index 4a80e9a58..23fa0d468 100644 --- a/src/utils/io.h +++ b/src/utils/io.h @@ -88,11 +88,19 @@ class IStream { } }; -/*! \brief implementation of file i/o stream */ -class FileStream : public IStream { - private: - FILE *fp; +/*! \brief interface of i/o stream that support seek */ +class ISeekStream: public IStream { public: + /*! \brief seek to certain position of the file */ + virtual void Seek(size_t pos) = 0; + /*! \brief tell the position of the stream */ + virtual size_t Tell(void) = 0; +}; + +/*! \brief implementation of file i/o stream */ +class FileStream : public ISeekStream { + public: + explicit FileStream(void) {} explicit FileStream(FILE *fp) { this->fp = fp; } @@ -102,12 +110,18 @@ class FileStream : public IStream { virtual void Write(const void *ptr, size_t size) { fwrite(ptr, size, 1, fp); } - inline void Seek(size_t pos) { - fseek(fp, 0, SEEK_SET); + virtual void Seek(size_t pos) { + fseek(fp, pos, SEEK_SET); + } + virtual size_t Tell(void) { + return static_cast(ftell(fp)); } inline void Close(void) { fclose(fp); } + + private: + FILE *fp; }; } // namespace utils diff --git a/src/utils/thread.h b/src/utils/thread.h new file mode 100644 index 000000000..830b21cbe --- /dev/null +++ b/src/utils/thread.h @@ -0,0 +1,146 @@ +#ifndef XGBOOST_UTILS_THREAD_H +#define XGBOOST_UTILS_THREAD_H +/*! 
+ * \file thread.h
+ * \brief this header includes the minimum necessary resources for multi-threading
+ * \author Tianqi Chen
+ * Acknowledgement: this file is adapted from SVDFeature project, by same author.
+ * The MAC support part of this code is provided by Artemy Kolchinsky
+ */
+#ifdef _MSC_VER
+#include "utils.h"
+#include <windows.h>
+#include <process.h>
+namespace xgboost {
+namespace utils {
+/*! \brief simple semaphore used for synchronization, wraps a Win32 semaphore handle */
+class Semaphore {
+ public :
+  inline void Init(int init_val) {
+    sem = CreateSemaphore(NULL, init_val, 10, NULL);  // third argument: maximum count is 10
+    utils::Assert(sem != NULL, "create Semaphore error");
+  }
+  inline void Destroy(void) {
+    CloseHandle(sem);
+  }
+  inline void Wait(void) {
+    // NOTE(review): the wait happens inside the Assert argument; assumes utils::Assert always evaluates it - confirm
+    utils::Assert(WaitForSingleObject(sem, INFINITE) == WAIT_OBJECT_0, "WaitForSingleObject error");
+  }
+  inline void Post(void) {
+    utils::Assert(ReleaseSemaphore(sem, 1, NULL) != 0, "ReleaseSemaphore error");
+  }
+ private:
+  HANDLE sem;
+};
+/*! \brief simple thread that wraps windows thread; Start launches entry(param), Join blocks until it ends */
+class Thread {
+ private:
+  HANDLE thread_handle;
+  unsigned thread_id;
+ public:
+  inline void Start(unsigned int __stdcall entry(void*), void *param) {
+    thread_handle = (HANDLE)_beginthreadex(NULL, 0, entry, param, 0, &thread_id);
+  }
+  inline int Join(void) {
+    WaitForSingleObject(thread_handle, INFINITE);  // NOTE(review): thread_handle is not closed in this class
+    return 0;
+  }
+};
+/*! \brief exit function called from thread, the status argument is not used on windows */
+inline void ThreadExit(void *status) {
+  _endthreadex(0);
+}
+#define XGBOOST_THREAD_PREFIX unsigned int __stdcall
+}  // namespace utils
+}  // namespace xgboost
+#else
+// thread interface using g++
+#include <semaphore.h>
+#include <pthread.h>
+namespace xgboost {
+namespace utils {
+/*!\brief semaphore class: named semaphore (sem_open) on Apple, unnamed (sem_init) elsewhere */
+class Semaphore {
+  #ifdef __APPLE__
+ private:
+  sem_t* semPtr;
+  char sema_name[21];  // "/se/" + 16 random characters + terminating NUL
+ private:
+  inline void GenRandomString(char *s, const int len) {
+    static const char alphanum[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" ;
+    for (int i = 0; i < len; ++i) {
+      s[i] = alphanum[rand() % (sizeof(alphanum) - 1)];
+    }
+    s[len] = 0;  // terminator goes to s[len], so the caller must provide len + 1 bytes
+  }
+ public:
+  inline void Init(int init_val) {
+    sema_name[0]='/';
+    sema_name[1]='s';
+    sema_name[2]='e';
+    sema_name[3]='/';
+    GenRandomString(&sema_name[4], 16);  // NOTE(review): rand() is not seeded here, names may repeat across processes
+    if((semPtr = sem_open(sema_name, O_CREAT, 0644, init_val)) == SEM_FAILED) {
+      perror("sem_open");
+      exit(1);
+    }
+    utils::Assert(semPtr != NULL, "create Semaphore error");
+  }
+  inline void Destroy(void) {
+    if (sem_close(semPtr) == -1) {
+      perror("sem_close");
+      exit(EXIT_FAILURE);
+    }
+    if (sem_unlink(sema_name) == -1) {
+      perror("sem_unlink");
+      exit(EXIT_FAILURE);
+    }
+  }
+  inline void Wait(void) {
+    sem_wait(semPtr);
+  }
+  inline void Post(void) {
+    sem_post(semPtr);
+  }
+  #else
+ private:
+  sem_t sem;
+ public:
+  inline void Init(int init_val) {
+    sem_init(&sem, 0, init_val);  // second argument 0: not shared between processes
+  }
+  inline void Destroy(void) {
+    sem_destroy(&sem);
+  }
+  inline void Wait(void) {
+    sem_wait(&sem);
+  }
+  inline void Post(void) {
+    sem_post(&sem);
+  }
+  #endif
+};
+/*!\brief simple thread class, wraps a joinable pthread */
+class Thread {
+ private:
+  pthread_t thread;
+ public :
+  inline void Start(void * entry(void*), void *param) {
+    pthread_attr_t attr;  // NOTE(review): attr is initialized below but pthread_attr_destroy is not called
+    pthread_attr_init(&attr);
+    pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
+    pthread_create(&thread, &attr, entry, param);
+  }
+  inline int Join(void) {
+    void *status;
+    return pthread_join(thread, &status);  // returns the pthread_join result, the thread's exit status is discarded
+  }
+};
+inline void ThreadExit(void *status) { + pthread_exit(status); +} +} // namespace utils +} // namespace xgboost +#define XGBOOST_THREAD_PREFIX void * +#endif +#endif diff --git a/src/utils/thread_buffer.h b/src/utils/thread_buffer.h new file mode 100644 index 000000000..fa488a220 --- /dev/null +++ b/src/utils/thread_buffer.h @@ -0,0 +1,200 @@ +#ifndef XGBOOST_UTILS_THREAD_BUFFER_H +#define XGBOOST_UTILS_THREAD_BUFFER_H +/*! + * \file thread_buffer.h + * \brief multi-thread buffer, iterator, can be used to create parallel pipeline + * \author Tianqi Chen + */ +#include +#include +#include +#include "./utils.h" +#include "./thread.h" +namespace xgboost { +namespace utils { +/*! + * \brief buffered loading iterator that uses multithread + * this template method will assume the following paramters + * \tparam Elem elememt type to be buffered + * \tparam ElemFactory factory type to implement in order to use thread buffer + */ +template +class ThreadBuffer { + public: + /*!\brief constructor */ + ThreadBuffer(void) { + this->init_end = false; + this->buf_size = 30; + } + ~ThreadBuffer(void) { + if(init_end) this->Destroy(); + } + /*!\brief set parameter, will also pass the parameter to factory */ + inline void SetParam(const char *name, const char *val) { + if (!strcmp( name, "buffer_size")) buf_size = atoi(val); + factory.SetParam(name, val); + } + /*! + * \brief initalize the buffered iterator + * \param param a initialize parameter that will pass to factory, ignore it if not necessary + * \return false if the initlization can't be done, e.g. 
buffer file hasn't been created + */ + inline bool Init(void) { + if (!factory.Init()) return false; + for (int i = 0; i < buf_size; ++i) { + bufA.push_back(factory.Create()); + bufB.push_back(factory.Create()); + } + this->init_end = true; + this->StartLoader(); + return true; + } + /*!\brief place the iterator before first value */ + inline void BeforeFirst(void) { + // wait till last loader end + loading_end.Wait(); + // critcal zone + current_buf = 1; + factory.BeforeFirst(); + // reset terminate limit + endA = endB = buf_size; + // wake up loader for first part + loading_need.Post(); + // wait til first part is loaded + loading_end.Wait(); + // set current buf to right value + current_buf = 0; + // wake loader for next part + data_loaded = false; + loading_need.Post(); + // set buffer value + buf_index = 0; + } + /*! \brief destroy the buffer iterator, will deallocate the buffer */ + inline void Destroy(void) { + // wait until the signal is consumed + this->destroy_signal = true; + loading_need.Post(); + loader_thread.Join(); + loading_need.Destroy(); + loading_end.Destroy(); + for (size_t i = 0; i < bufA.size(); ++i) { + factory.FreeSpace(bufA[i]); + } + for (size_t i = 0; i < bufB.size(); ++i) { + factory.FreeSpace(bufB[i]); + } + bufA.clear(); bufB.clear(); + factory.Destroy(); + this->init_end = false; + } + /*! + * \brief get the next element needed in buffer + * \param elem element to store into + * \return whether reaches end of data + */ + inline bool Next(Elem &elem) { + // end of buffer try to switch + if (buf_index == buf_size) { + this->SwitchBuffer(); + buf_index = 0; + } + if (buf_index >= (current_buf ? endA : endB)) { + return false; + } + std::vector &buf = current_buf ? bufA : bufB; + elem = buf[buf_index]; + ++buf_index; + return true; + } + /*! 
+ * \brief get the factory object + */ + inline ElemFactory &get_factory(void) { + return factory; + } + // size of buffer + int buf_size; + private: + // factory object used to load configures + ElemFactory factory; + // index in current buffer + int buf_index; + // indicate which one is current buffer + int current_buf; + // max limit of visit, also marks termination + int endA, endB; + // double buffer, one is accessed by loader + // the other is accessed by consumer + // buffer of the data + std::vector bufA, bufB; + // initialization end + bool init_end; + // singal whether the data is loaded + bool data_loaded; + // signal to kill the thread + bool destroy_signal; + // thread object + Thread loader_thread; + // signal of the buffer + Semaphore loading_end, loading_need; + /*! + * \brief slave thread + * this implementation is like producer-consumer style + */ + inline void RunLoader(void) { + while(!destroy_signal) { + // sleep until loading is needed + loading_need.Wait(); + std::vector &buf = current_buf ? bufB : bufA; + int i; + for (i = 0; i < buf_size ; ++i) { + if (!factory.LoadNext(buf[i])) { + int &end = current_buf ? 
endB : endA; + end = i; // marks the termination + break; + } + } + // signal that loading is done + data_loaded = true; + loading_end.Post(); + } + } + /*!\brief entry point of loader thread */ + inline static XGBOOST_THREAD_PREFIX LoaderEntry(void *pthread) { + static_cast< ThreadBuffer* >(pthread)->RunLoader(); + ThreadExit(NULL); + return NULL; + } + /*!\brief start loader thread */ + inline void StartLoader(void) { + destroy_signal = false; + // set param + current_buf = 1; + loading_need.Init(1); + loading_end .Init(0); + // reset terminate limit + endA = endB = buf_size; + loader_thread.Start(LoaderEntry, this); + // wait until first part of data is loaded + loading_end.Wait(); + // set current buf to right value + current_buf = 0; + // wake loader for next part + data_loaded = false; + loading_need.Post(); + buf_index = 0; + } + /*!\brief switch double buffer */ + inline void SwitchBuffer(void) { + loading_end.Wait(); + // loader shall be sleep now, critcal zone! + current_buf = !current_buf; + // wake up loader + data_loaded = false; + loading_need.Post(); + } +}; +} // namespace utils +} // namespace xgboost +#endif diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index e2cbdba2e..adf59c829 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -3,10 +3,11 @@ import ctypes import os # optinally have scipy sparse, though not necessary -import numpy +import numpy as np import sys import numpy.ctypeslib import scipy.sparse as scp +import random # set this line correctly if os.name == 'nt': @@ -32,16 +33,28 @@ xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p) def ctypes2numpy(cptr, length, dtype): - # convert a ctypes pointer array to numpy + """convert a ctypes pointer array to numpy array """ assert isinstance(cptr, ctypes.POINTER(ctypes.c_float)) res = numpy.zeros(length, dtype=dtype) assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]) return res -# data matrix used in xgboost class DMatrix: + """data matrix used 
in xgboost""" # constructor def __init__(self, data, label=None, missing=0.0, weight = None): + """ constructor of DMatrix + + Args: + data: string/numpy array/scipy.sparse + data source, string type is the path of svmlight format txt file or xgb buffer + label: list or numpy 1d array, optional + label of training data + missing: float + value in data which need to be present as missing value + weight: list or numpy 1d array, optional + weight for each instances + """ # force into void_p, mac need to pass things in as void_p if data == None: self.handle = None @@ -63,22 +76,25 @@ class DMatrix: self.set_label(label) if weight !=None: self.set_weight(weight) - # convert data from csr matrix + def __init_from_csr(self, csr): + """convert data from csr matrix""" assert len(csr.indices) == len(csr.data) self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR( (ctypes.c_ulong * len(csr.indptr))(*csr.indptr), (ctypes.c_uint * len(csr.indices))(*csr.indices), (ctypes.c_float * len(csr.data))(*csr.data), len(csr.indptr), len(csr.data))) - # convert data from numpy matrix + def __init_from_npy2d(self,mat,missing): + """convert data from numpy matrix""" data = numpy.array(mat.reshape(mat.size), dtype='float32') self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat( data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), mat.shape[0], mat.shape[1], ctypes.c_float(missing))) - # destructor + def __del__(self): + """destructor""" xglib.XGDMatrixFree(self.handle) def get_float_info(self, field): length = ctypes.c_ulong() @@ -96,16 +112,39 @@ class DMatrix: def set_uint_info(self, field, data): xglib.XGDMatrixSetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), (ctypes.c_uint*len(data))(*data), len(data)) - # load data from file + def save_binary(self, fname, silent=True): + """save DMatrix to XGBoost buffer + Args: + fname: string + name of buffer file + slient: bool, option + whether print info + Returns: + None + """ xglib.XGDMatrixSaveBinary(self.handle, 
ctypes.c_char_p(fname.encode('utf-8')), int(silent)) - # set label of dmatrix + def set_label(self, label): + """set label of dmatrix + Args: + label: list + label for DMatrix + Returns: + None + """ self.set_float_info('label', label) - # set weight of each instances + def set_weight(self, weight): + """set weight of each instances + Args: + weight: float + weight for positive instance + Returns: + None + """ self.set_float_info('weight', weight) - # set initialized margin prediction + def set_base_margin(self, margin): """ set base margin of booster to start from @@ -116,31 +155,143 @@ class DMatrix: see also example/demo.py """ self.set_float_info('base_margin', margin) - # set group size of dmatrix, used for rank + def set_group(self, group): + """set group size of dmatrix, used for rank + Args: + group: + + Returns: + None + """ xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group)) - # get label from dmatrix + def get_label(self): + """get label from dmatrix + Args: + None + Returns: + list, label of data + """ return self.get_float_info('label') - # get weight from dmatrix + def get_weight(self): + """get weight from dmatrix + Args: + None + Returns: + float, weight + """ return self.get_float_info('weight') - # get base_margin from dmatrix def get_base_margin(self): + """get base_margin from dmatrix + Args: + None + Returns: + float, base margin + """ return self.get_float_info('base_margin') def num_row(self): + """get number of rows + Args: + None + Returns: + int, num rows + """ return xglib.XGDMatrixNumRow(self.handle) - # slice the DMatrix to return a new DMatrix that only contains rindex def slice(self, rindex): + """slice the DMatrix to return a new DMatrix that only contains rindex + Args: + rindex: list + list of index to be chosen + Returns: + res: DMatrix + new DMatrix with chosen index + """ res = DMatrix(None) res.handle = ctypes.c_void_p(xglib.XGDMatrixSliceDMatrix( self.handle, 
(ctypes.c_int*len(rindex))(*rindex), len(rindex))) return res +class CVPack: + def __init__(self, dtrain, dtest, param): + self.dtrain = dtrain + self.dtest = dtest + self.watchlist = watchlist = [ (dtrain,'train'), (dtest, 'test') ] + self.bst = Booster(param, [dtrain,dtest]) + def update(self,r): + self.bst.update(self.dtrain, r) + def eval(self,r): + return self.bst.eval_set(self.watchlist, r) + +def mknfold(dall, nfold, param, seed, weightscale=None): + """ + mk nfold list of cvpack from randidx + """ + randidx = range(dall.num_row()) + random.seed(seed) + random.shuffle(randidx) + + idxset = [] + kstep = len(randidx) / nfold + for i in range(nfold): + idxset.append(randidx[ (i*kstep) : min(len(randidx),(i+1)*kstep) ]) + + ret = [] + for k in range(nfold): + trainlst = [] + for j in range(nfold): + if j == k: + testlst = idxset[j] + else: + trainlst += idxset[j] + dtrain = dall.slice(trainlst) + dtest = dall.slice(testlst) + # rescale weight of dtrain and dtest + if weightscale != None: + dtrain.set_weight( dtrain.get_weight() * weightscale * dall.num_row() / dtrain.num_row() ) + dtest.set_weight( dtest.get_weight() * weightscale * dall.num_row() / dtest.num_row() ) + + ret.append(CVPack(dtrain, dtest, param)) + return ret + +def aggcv(rlist): + """ + aggregate cross validation results + """ + cvmap = {} + arr = rlist[0].split() + ret = arr[0] + for it in arr[1:]: + k, v = it.split(':') + cvmap[k] = [float(v)] + for line in rlist[1:]: + arr = line.split() + assert ret == arr[0] + for it in arr[1:]: + k, v = it.split(':') + cvmap[k].append(float(v)) + + for k, v in sorted(cvmap.items(), key = lambda x:x[0]): + v = np.array(v) + ret += '\t%s:%f+%f' % (k, np.mean(v), np.std(v)) + return ret + + class Booster: """learner class """ def __init__(self, params={}, cache=[], model_file = None): - """ constructor, param: """ + """ constructor + Args: + params: dict + params for boosters + cache: list + list of cache item + model_file: string + path of model file + 
Returns: + None + """ for d in cache: assert isinstance(d, DMatrix) dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache]) @@ -166,16 +317,30 @@ class Booster: xglib.XGBoosterSetParam( self.handle, ctypes.c_char_p(k.encode('utf-8')), ctypes.c_char_p(str(v).encode('utf-8'))) + def update(self, dtrain, it): """ update - dtrain: the training DMatrix - it: current iteration number + Args: + dtrain: DMatrix + the training DMatrix + it: int + current iteration number + Returns: + None """ assert isinstance(dtrain, DMatrix) xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle) def boost(self, dtrain, grad, hess): - """ update """ + """ update + Args: + dtrain: DMatrix + the training DMatrix + grad: list + the first order of gradient + hess: list + the second order of gradient + """ assert len(grad) == len(hess) assert isinstance(dtrain, DMatrix) xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle, @@ -183,6 +348,14 @@ class Booster: (ctypes.c_float*len(hess))(*hess), len(grad)) def eval_set(self, evals, it = 0): + """evaluates by metric + Args: + evals: list of tuple (DMatrix, string) + lists of items to be evaluated + it: int + Returns: + evals result + """ for d in evals: assert isinstance(d[0], DMatrix) assert isinstance(d[1], str) @@ -195,21 +368,46 @@ class Booster: def predict(self, data, output_margin=False): """ predict with data - data: the dmatrix storing the input - output_margin: whether output raw margin value that is untransformed + Args: + data: DMatrix + the dmatrix storing the input + output_margin: bool + whether output raw margin value that is untransformed + Returns: + numpy array of prediction """ length = ctypes.c_ulong() preds = xglib.XGBoosterPredict(self.handle, data.handle, int(output_margin), ctypes.byref(length)) return ctypes2numpy(preds, length.value, 'float32') def save_model(self, fname): - """ save model to file """ + """ save model to file + Args: + fname: string + file name of saving model + Returns: + None + """ 
xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8'))) def load_model(self, fname): - """load model from file""" + """load model from file + Args: + fname: string + file name of saving model + Returns: + None + """ xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) ) def dump_model(self, fo, fmap=''): - """dump model into text file""" + """dump model into text file + Args: + fo: string + file name to be dumped + fmap: string, optional + file name of feature map names + Returns: + None + """ if isinstance(fo,str): fo = open(fo,'w') need_close = True @@ -248,7 +446,17 @@ class Booster: return fmap def evaluate(bst, evals, it, feval = None): - """evaluation on eval set""" + """evaluation on eval set + Args: + bst: XGBoost object + object of XGBoost model + evals: list of tuple (DMatrix, string) + obj need to be evaluated + it: int + feval: optional + Returns: + eval result + """ if feval != None: res = '[%d]' % it for dm, evname in evals: @@ -259,8 +467,22 @@ def evaluate(bst, evals, it, feval = None): return res + + def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None): - """ train a booster with given paramaters """ + """ train a booster with given paramaters + Args: + params: dict + params of booster + dtrain: DMatrix + data to be trained + num_boost_round: int + num of round to be boosted + evals: list + list of items to be evaluated + obj: + feval: + """ bst = Booster(params, [dtrain]+[ d[0] for d in evals ] ) if obj == None: for i in range(num_boost_round): @@ -276,3 +498,27 @@ def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None if len(evals) != 0: sys.stderr.write(evaluate(bst, evals, i, feval)+'\n') return bst + +def cv(params, dtrain, num_boost_round = 10, nfold=3, evals = [], obj=None, feval=None): + """ cross validation with given paramaters + Args: + params: dict + params of booster + dtrain: DMatrix + data to be trained + num_boost_round: int + num 
of round to be boosted + nfold: int + folds to do cv + evals: list + list of items to be evaluated + obj: + feval: + """ + plst = list(params.items())+[('eval_metric', itm) for itm in evals] + cvfolds = mknfold(dtrain, nfold, plst, 0) + for i in range(num_boost_round): + for f in cvfolds: + f.update(i) + res = aggcv([f.eval(i) for f in cvfolds]) + sys.stderr.write(res+'\n')