good
This commit is contained in:
parent
59939d0b14
commit
ccd037292d
@ -1,5 +1,7 @@
|
|||||||
# module for xgboost
|
# module for xgboost
|
||||||
import ctypes
|
import ctypes
|
||||||
|
import numpy
|
||||||
|
import scipy.sparse as scp
|
||||||
|
|
||||||
# load in xgboost library
|
# load in xgboost library
|
||||||
xglib = ctypes.cdll.LoadLibrary('./libxgboostpy.so')
|
xglib = ctypes.cdll.LoadLibrary('./libxgboostpy.so')
|
||||||
@ -8,15 +10,57 @@ xglib = ctypes.cdll.LoadLibrary('./libxgboostpy.so')
|
|||||||
class REntry(ctypes.Structure):
|
class REntry(ctypes.Structure):
|
||||||
_fields_ = [("findex", ctypes.c_uint), ("fvalue", ctypes.c_float) ]
|
_fields_ = [("findex", ctypes.c_uint), ("fvalue", ctypes.c_float) ]
|
||||||
|
|
||||||
|
# data matrix used in xgboost
|
||||||
class DMatrix:
|
class DMatrix:
|
||||||
def __init__(self,fname = None):
|
# constructor
|
||||||
self.__handle = xglib.XGDMatrixCreate();
|
def __init__(self, data=None, label=None):
|
||||||
if fname != None:
|
self.handle = xglib.XGDMatrixCreate();
|
||||||
xglib.XGDMatrixLoad(self.__handle, ctypes.c_char_p(fname), 0)
|
if data == None:
|
||||||
|
return
|
||||||
|
if type(data) is str:
|
||||||
|
xglib.XGDMatrixLoad(self.handle, ctypes.c_char_p(data), 1)
|
||||||
|
elif type(data) is scp.csr_matrix:
|
||||||
|
self.__init_from_csr(data)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
csr = scp.csr_matrix(data)
|
||||||
|
self.__init_from_csr(data)
|
||||||
|
except:
|
||||||
|
raise "DMatrix", "can not intialize DMatrix from"+type(data)
|
||||||
|
if label != None:
|
||||||
|
self.set_label(label)
|
||||||
|
|
||||||
|
# convert data from csr matrix
|
||||||
|
def __init_from_csr(self,csr):
|
||||||
|
assert len(csr.indices) == len(csr.data)
|
||||||
|
xglib.XGDMatrixParseCSR( self.handle,
|
||||||
|
( ctypes.c_ulong * len(csr.indptr) )(*csr.indptr),
|
||||||
|
( ctypes.c_uint * len(csr.indices) )(*csr.indices),
|
||||||
|
( ctypes.c_float * len(csr.data) )(*csr.data),
|
||||||
|
len(csr.indptr), len(csr.data) )
|
||||||
|
# destructor
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
xglib.XGDMatrixFree(self.__handle)
|
xglib.XGDMatrixFree(self.handle)
|
||||||
|
# load data from file
|
||||||
dmata = DMatrix('xx.buffer')
|
def load(self, fname):
|
||||||
|
xglib.XGDMatrixLoad(self.handle, ctypes.c_char_p(fname), 1)
|
||||||
|
# set label of dmatrix
|
||||||
|
def set_label(self, label):
|
||||||
|
xglib.XGDMatrixSetLabel(self.handle, (ctypes.c_float*len(label))(*label), len(label) );
|
||||||
|
# get label from dmatrix
|
||||||
|
def get_label(self):
|
||||||
|
GetLabel = xglib.XGDMatrixGetLabel
|
||||||
|
GetLabel.restype = ctypes.POINTER( ctypes.c_float )
|
||||||
|
length = ctypes.c_ulong()
|
||||||
|
labels = GetLabel(self.handle, ctypes.byref(length));
|
||||||
|
return [ labels[i] for i in xrange(length.value) ]
|
||||||
|
# append a row to DMatrix
|
||||||
|
def add_row(self, row):
|
||||||
|
xglib.XGDMatrixAddRow(self.handle, (REntry*len(row))(*row), len(row) );
|
||||||
|
|
||||||
|
|
||||||
|
mat = DMatrix('xx.buffer')
|
||||||
|
lb = mat.get_label()
|
||||||
|
print len(lb)
|
||||||
|
mat.set_label(lb)
|
||||||
|
mat.add_row( [(1,2), (3,4)] )
|
||||||
|
|||||||
@ -18,6 +18,37 @@ namespace xgboost{
|
|||||||
this->CacheLoad(fname, silent);
|
this->CacheLoad(fname, silent);
|
||||||
init_col_ = this->data.HaveColAccess();
|
init_col_ = this->data.HaveColAccess();
|
||||||
}
|
}
|
||||||
|
inline void AddRow( const XGEntry *data, size_t len ){
|
||||||
|
xgboost::booster::FMatrixS &mat = this->data;
|
||||||
|
mat.row_data_.resize( mat.row_ptr_.back() + len );
|
||||||
|
memcpy( &mat.row_data_[mat.row_ptr_.back()], data, sizeof(XGEntry)*len );
|
||||||
|
mat.row_ptr_.push_back( mat.row_ptr_.back() + len );
|
||||||
|
}
|
||||||
|
inline void ParseCSR( const size_t *indptr,
|
||||||
|
const unsigned *indices,
|
||||||
|
const float *data,
|
||||||
|
size_t nindptr,
|
||||||
|
size_t nelem ){
|
||||||
|
xgboost::booster::FMatrixS &mat = this->data;
|
||||||
|
mat.row_ptr_.resize( nindptr );
|
||||||
|
memcpy( &mat.row_ptr_[0], indptr, sizeof(size_t)*nindptr );
|
||||||
|
mat.row_data_.resize( nelem );
|
||||||
|
for( size_t i = 0; i < nelem; ++ i ){
|
||||||
|
mat.row_data_[i] = XGEntry(indices[i], data[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inline void SetLabel( const float *label, size_t len ){
|
||||||
|
this->info.labels.resize( len );
|
||||||
|
memcpy( &(this->info).labels[0], label, sizeof(float)*len );
|
||||||
|
}
|
||||||
|
inline const float* GetLabel( size_t* len ) const{
|
||||||
|
*len = this->info.labels.size();
|
||||||
|
return &(this->info.labels[0]);
|
||||||
|
}
|
||||||
|
inline void InitTrain(void){
|
||||||
|
if(!this->data.HaveColAccess()) this->data.InitData();
|
||||||
|
utils::Assert( this->data.NumRow() == this->info.labels.size(), "DMatrix: number of labels must match number of rows in matrix");
|
||||||
|
}
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
@ -25,17 +56,34 @@ namespace xgboost{
|
|||||||
using namespace xgboost::python;
|
using namespace xgboost::python;
|
||||||
|
|
||||||
extern "C"{
|
extern "C"{
|
||||||
void* XGDMatrixCreate(void){
|
void* XGDMatrixCreate( void ){
|
||||||
return new DMatrix();
|
return new DMatrix();
|
||||||
}
|
}
|
||||||
void XGDMatrixFree(void *handle){
|
void XGDMatrixFree( void *handle ){
|
||||||
delete static_cast<DMatrix*>(handle);
|
delete static_cast<DMatrix*>(handle);
|
||||||
}
|
}
|
||||||
void XGDMatrixLoad(void *handle, const char *fname, int silent){
|
void XGDMatrixLoad( void *handle, const char *fname, int silent ){
|
||||||
static_cast<DMatrix*>(handle)->Load(fname, silent!=0);
|
static_cast<DMatrix*>(handle)->Load(fname, silent!=0);
|
||||||
}
|
}
|
||||||
void XGDMatrixSaveBinary(void *handle, const char *fname, int silent){
|
void XGDMatrixSaveBinary( void *handle, const char *fname, int silent ){
|
||||||
static_cast<DMatrix*>(handle)->SaveBinary(fname, silent!=0);
|
static_cast<DMatrix*>(handle)->SaveBinary(fname, silent!=0);
|
||||||
}
|
}
|
||||||
|
void XGDMatrixAddRow( void *handle, const XGEntry *data, size_t len ){
|
||||||
|
static_cast<DMatrix*>(handle)->AddRow(data, len);
|
||||||
|
}
|
||||||
|
void XGDMatrixParseCSR( void *handle,
|
||||||
|
const size_t *indptr,
|
||||||
|
const unsigned *indices,
|
||||||
|
const float *data,
|
||||||
|
size_t nindptr,
|
||||||
|
size_t nelem ){
|
||||||
|
static_cast<DMatrix*>(handle)->ParseCSR(indptr, indices, data, nindptr, nelem);
|
||||||
|
}
|
||||||
|
void XGDMatrixSetLabel( void *handle, const float *label, size_t len ){
|
||||||
|
static_cast<DMatrix*>(handle)->SetLabel(label,len);
|
||||||
|
}
|
||||||
|
const float* XGDMatrixGetLabel( const void *handle, size_t* len ){
|
||||||
|
return static_cast<const DMatrix*>(handle)->GetLabel(len);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -34,14 +34,41 @@ extern "C"{
|
|||||||
* \param silent print statistics when saving
|
* \param silent print statistics when saving
|
||||||
*/
|
*/
|
||||||
void XGDMatrixSaveBinary(void *handle, const char *fname, int silent);
|
void XGDMatrixSaveBinary(void *handle, const char *fname, int silent);
|
||||||
|
/*!
|
||||||
|
* \brief set matrix content from csr format
|
||||||
|
* \param handle a instance of data matrix
|
||||||
|
* \param indptr pointer to row headers
|
||||||
|
* \param indices findex
|
||||||
|
* \param data fvalue
|
||||||
|
* \param nindptr number of rows in the matix + 1
|
||||||
|
* \param nelem number of nonzero elements in the matrix
|
||||||
|
*/
|
||||||
|
void XGDMatrixParseCSR( void *handle,
|
||||||
|
const size_t *indptr,
|
||||||
|
const unsigned *indices,
|
||||||
|
const float *data,
|
||||||
|
size_t nindptr,
|
||||||
|
size_t nelem );
|
||||||
|
/*!
|
||||||
|
* \brief set label of the training matrix
|
||||||
|
* \param handle a instance of data matrix
|
||||||
|
* \param data array of row content
|
||||||
|
* \param len length of array
|
||||||
|
*/
|
||||||
|
void XGDMatrixSetLabel( void *handle, const float *label, size_t len );
|
||||||
|
/*!
|
||||||
|
* \brief get label set from matrix
|
||||||
|
* \param handle a instance of data matrix
|
||||||
|
* \param len used to set result length
|
||||||
|
*/
|
||||||
|
const float* XGDMatrixGetLabel( const void *handle, size_t* len );
|
||||||
/*!
|
/*!
|
||||||
* \brief add row
|
* \brief add row
|
||||||
* \param handle a instance of data matrix
|
* \param handle a instance of data matrix
|
||||||
* \param fname file name
|
* \param data array of row content
|
||||||
* \return a new data matrix
|
* \param len length of array
|
||||||
*/
|
*/
|
||||||
void XGDMatrixPush(void *handle, const XGEntry *data, int len);
|
void XGDMatrixAddRow(void *handle, const XGEntry *data, size_t len);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief create a booster
|
* \brief create a booster
|
||||||
*/
|
*/
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user