finish matrix

This commit is contained in:
tqchen
2014-05-03 17:12:25 -07:00
parent 5bab27cfa6
commit 20de7f8f97
3 changed files with 124 additions and 28 deletions

View File

@@ -1,15 +1,20 @@
# module for xgboost
import ctypes
import numpy
# optinally have scipy sparse, though not necessary
import scipy.sparse as scp
# load in xgboost library
xglib = ctypes.cdll.LoadLibrary('./libxgboostpy.so')
# entry type of sparse matrix
class REntry(ctypes.Structure):
_fields_ = [("findex", ctypes.c_uint), ("fvalue", ctypes.c_float) ]
# load in xgboost library
xglib = ctypes.cdll.LoadLibrary('./libxgboostpy.so')
xglib.XGDMatrixCreate.restype = ctypes.c_void_p
xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
xglib.XGDMatrixGetLabel.restype = ctypes.POINTER( ctypes.c_float )
xglib.XGDMatrixGetRow.restype = ctypes.POINTER( REntry )
# data matrix used in xgboost
class DMatrix:
# constructor
@@ -40,27 +45,37 @@ class DMatrix:
len(csr.indptr), len(csr.data) )
# destructor
def __del__(self):
xglib.XGDMatrixFree(self.handle)
xglib.XGDMatrixFree(self.handle)
# load data from file
def load(self, fname):
xglib.XGDMatrixLoad(self.handle, ctypes.c_char_p(fname), 1)
def load(self, fname, silent=True):
xglib.XGDMatrixLoad(self.handle, ctypes.c_char_p(fname), int(silent))
# load data from file
def save_binary(self, fname, silent=True):
xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname), int(silent))
# set label of dmatrix
def set_label(self, label):
xglib.XGDMatrixSetLabel(self.handle, (ctypes.c_float*len(label))(*label), len(label) );
# get label from dmatrix
def get_label(self):
GetLabel = xglib.XGDMatrixGetLabel
GetLabel.restype = ctypes.POINTER( ctypes.c_float )
length = ctypes.c_ulong()
labels = GetLabel(self.handle, ctypes.byref(length));
labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length));
return [ labels[i] for i in xrange(length.value) ]
# clear everything
def clear(self):
xglib.XGDMatrixClear(self.handle)
def num_row(self):
return xglib.XGDMatrixNumRow(self.handle)
# append a row to DMatrix
def add_row(self, row):
xglib.XGDMatrixAddRow(self.handle, (REntry*len(row))(*row), len(row) );
def add_row(self, row, label):
xglib.XGDMatrixAddRow(self.handle, (REntry*len(row))(*row), len(row), label )
# get n-throw from DMatrix
def __getitem__(self, ridx):
length = ctypes.c_ulong()
row = xglib.XGDMatrixGetRow(self.handle, ridx, ctypes.byref(length) );
return [ (int(row[i].findex),row[i].fvalue) for i in xrange(length.value) ]
mat = DMatrix('xx.buffer')
lb = mat.get_label()
print len(lb)
mat.set_label(lb)
mat.add_row( [(1,2), (3,4)] )
print mat.num_row()
mat.clear()