diff --git a/python/xgboost.py b/python/xgboost.py index 580b80306..63326b0f5 100644 --- a/python/xgboost.py +++ b/python/xgboost.py @@ -1,5 +1,7 @@ # module for xgboost import ctypes +import numpy +import scipy.sparse as scp # load in xgboost library xglib = ctypes.cdll.LoadLibrary('./libxgboostpy.so') @@ -8,15 +10,57 @@ xglib = ctypes.cdll.LoadLibrary('./libxgboostpy.so') class REntry(ctypes.Structure): _fields_ = [("findex", ctypes.c_uint), ("fvalue", ctypes.c_float) ] - +# data matrix used in xgboost class DMatrix: - def __init__(self,fname = None): - self.__handle = xglib.XGDMatrixCreate(); - if fname != None: - xglib.XGDMatrixLoad(self.__handle, ctypes.c_char_p(fname), 0) + # constructor + def __init__(self, data=None, label=None): + self.handle = xglib.XGDMatrixCreate(); + if data == None: + return + if type(data) is str: + xglib.XGDMatrixLoad(self.handle, ctypes.c_char_p(data), 1) + elif type(data) is scp.csr_matrix: + self.__init_from_csr(data) + else: + try: + csr = scp.csr_matrix(data) + self.__init_from_csr(data) + except: + raise "DMatrix", "can not intialize DMatrix from"+type(data) + if label != None: + self.set_label(label) + + # convert data from csr matrix + def __init_from_csr(self,csr): + assert len(csr.indices) == len(csr.data) + xglib.XGDMatrixParseCSR( self.handle, + ( ctypes.c_ulong * len(csr.indptr) )(*csr.indptr), + ( ctypes.c_uint * len(csr.indices) )(*csr.indices), + ( ctypes.c_float * len(csr.data) )(*csr.data), + len(csr.indptr), len(csr.data) ) + # destructor def __del__(self): - xglib.XGDMatrixFree(self.__handle) - -dmata = DMatrix('xx.buffer') - + xglib.XGDMatrixFree(self.handle) + # load data from file + def load(self, fname): + xglib.XGDMatrixLoad(self.handle, ctypes.c_char_p(fname), 1) + # set label of dmatrix + def set_label(self, label): + xglib.XGDMatrixSetLabel(self.handle, (ctypes.c_float*len(label))(*label), len(label) ); + # get label from dmatrix + def get_label(self): + GetLabel = xglib.XGDMatrixGetLabel + GetLabel.restype = ctypes.POINTER( ctypes.c_float ) + length = ctypes.c_ulong() + labels = GetLabel(self.handle, ctypes.byref(length)); + return [ labels[i] for i in xrange(length.value) ] + # append a row to DMatrix + def add_row(self, row): + xglib.XGDMatrixAddRow(self.handle, (REntry*len(row))(*row), len(row) ); + +mat = DMatrix('xx.buffer') +lb = mat.get_label() +print len(lb) +mat.set_label(lb) +mat.add_row( [(1,2), (3,4)] ) diff --git a/python/xgboost_python.cpp b/python/xgboost_python.cpp index ee97c68d3..71caff1d0 100644 --- a/python/xgboost_python.cpp +++ b/python/xgboost_python.cpp @@ -18,6 +18,37 @@ namespace xgboost{ this->CacheLoad(fname, silent); init_col_ = this->data.HaveColAccess(); } + inline void AddRow( const XGEntry *data, size_t len ){ + xgboost::booster::FMatrixS &mat = this->data; + mat.row_data_.resize( mat.row_ptr_.back() + len ); + memcpy( &mat.row_data_[mat.row_ptr_.back()], data, sizeof(XGEntry)*len ); + mat.row_ptr_.push_back( mat.row_ptr_.back() + len ); + } + inline void ParseCSR( const size_t *indptr, + const unsigned *indices, + const float *data, + size_t nindptr, + size_t nelem ){ + xgboost::booster::FMatrixS &mat = this->data; + mat.row_ptr_.resize( nindptr ); + memcpy( &mat.row_ptr_[0], indptr, sizeof(size_t)*nindptr ); + mat.row_data_.resize( nelem ); + for( size_t i = 0; i < nelem; ++ i ){ + mat.row_data_[i] = XGEntry(indices[i], data[i]); + } + } + inline void SetLabel( const float *label, size_t len ){ + this->info.labels.resize( len ); + memcpy( &(this->info).labels[0], label, sizeof(float)*len ); + } + inline const float* GetLabel( size_t* len ) const{ + *len = this->info.labels.size(); + return &(this->info.labels[0]); + } + inline void InitTrain(void){ + if(!this->data.HaveColAccess()) this->data.InitData(); + utils::Assert( this->data.NumRow() == this->info.labels.size(), "DMatrix: number of labels must match number of rows in matrix"); + } }; }; }; @@ -25,17 +56,34 @@ namespace xgboost{ using namespace xgboost::python; extern "C"{ - void* XGDMatrixCreate(void){ + void* XGDMatrixCreate( void ){ return new DMatrix(); } - void XGDMatrixFree(void *handle){ + void XGDMatrixFree( void *handle ){ delete static_cast(handle); } - void XGDMatrixLoad(void *handle, const char *fname, int silent){ + void XGDMatrixLoad( void *handle, const char *fname, int silent ){ static_cast(handle)->Load(fname, silent!=0); } - void XGDMatrixSaveBinary(void *handle, const char *fname, int silent){ + void XGDMatrixSaveBinary( void *handle, const char *fname, int silent ){ static_cast(handle)->SaveBinary(fname, silent!=0); } + void XGDMatrixAddRow( void *handle, const XGEntry *data, size_t len ){ + static_cast(handle)->AddRow(data, len); + } + void XGDMatrixParseCSR( void *handle, + const size_t *indptr, + const unsigned *indices, + const float *data, + size_t nindptr, + size_t nelem ){ + static_cast(handle)->ParseCSR(indptr, indices, data, nindptr, nelem); + } + void XGDMatrixSetLabel( void *handle, const float *label, size_t len ){ + static_cast(handle)->SetLabel(label,len); + } + const float* XGDMatrixGetLabel( const void *handle, size_t* len ){ + return static_cast(handle)->GetLabel(len); + } }; diff --git a/python/xgboost_python.h b/python/xgboost_python.h index ead07200d..2869c7aeb 100644 --- a/python/xgboost_python.h +++ b/python/xgboost_python.h @@ -34,14 +34,41 @@ extern "C"{ * \param silent print statistics when saving */ void XGDMatrixSaveBinary(void *handle, const char *fname, int silent); + /*! + * \brief set matrix content from csr format + * \param handle a instance of data matrix + * \param indptr pointer to row headers + * \param indices findex + * \param data fvalue + * \param nindptr number of rows in the matix + 1 + * \param nelem number of nonzero elements in the matrix + */ + void XGDMatrixParseCSR( void *handle, + const size_t *indptr, + const unsigned *indices, + const float *data, + size_t nindptr, + size_t nelem ); + /*! + * \brief set label of the training matrix + * \param handle a instance of data matrix + * \param data array of row content + * \param len length of array + */ + void XGDMatrixSetLabel( void *handle, const float *label, size_t len ); + /*! + * \brief get label set from matrix + * \param handle a instance of data matrix + * \param len used to set result length + */ + const float* XGDMatrixGetLabel( const void *handle, size_t* len ); /*! * \brief add row * \param handle a instance of data matrix - * \param fname file name - * \return a new data matrix + * \param data array of row content + * \param len length of array */ - void XGDMatrixPush(void *handle, const XGEntry *data, int len); - + void XGDMatrixAddRow(void *handle, const XGEntry *data, size_t len); /*! * \brief create a booster */