faster convert to numpy array
This commit is contained in:
parent
a0c0fbbb61
commit
2ccd28339e
@ -22,6 +22,13 @@ xglib.XGDMatrixGetLabel.restype = ctypes.POINTER( ctypes.c_float )
|
|||||||
xglib.XGDMatrixGetRow.restype = ctypes.POINTER( REntry )
|
xglib.XGDMatrixGetRow.restype = ctypes.POINTER( REntry )
|
||||||
xglib.XGBoosterPredict.restype = ctypes.POINTER( ctypes.c_float )
|
xglib.XGBoosterPredict.restype = ctypes.POINTER( ctypes.c_float )
|
||||||
|
|
||||||
|
def ctypes2numpy( cptr, length ):
|
||||||
|
# convert a ctypes pointer array to numpy
|
||||||
|
assert isinstance( cptr, ctypes.POINTER( ctypes.c_float ) )
|
||||||
|
res = numpy.zeros( length, dtype='float32' )
|
||||||
|
assert ctypes.memmove( res.ctypes.data, cptr, length * res.strides[0] )
|
||||||
|
return res
|
||||||
|
|
||||||
# data matrix used in xgboost
|
# data matrix used in xgboost
|
||||||
class DMatrix:
|
class DMatrix:
|
||||||
# constructor
|
# constructor
|
||||||
@ -73,7 +80,7 @@ class DMatrix:
|
|||||||
def get_label(self):
|
def get_label(self):
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length))
|
labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length))
|
||||||
return numpy.array( [labels[i] for i in xrange(length.value)] )
|
return ctypes2numpy( labels, length.value );
|
||||||
# clear everything
|
# clear everything
|
||||||
def clear(self):
|
def clear(self):
|
||||||
xglib.XGDMatrixClear(self.handle)
|
xglib.XGDMatrixClear(self.handle)
|
||||||
@ -138,7 +145,7 @@ class Booster:
|
|||||||
def predict(self, data, bst_group = -1):
|
def predict(self, data, bst_group = -1):
|
||||||
length = ctypes.c_ulong()
|
length = ctypes.c_ulong()
|
||||||
preds = xglib.XGBoosterPredict( self.handle, data.handle, ctypes.byref(length), bst_group)
|
preds = xglib.XGBoosterPredict( self.handle, data.handle, ctypes.byref(length), bst_group)
|
||||||
return numpy.array( [ preds[i] for i in xrange(length.value)])
|
return ctypes2numpy( preds, length.value )
|
||||||
def save_model(self, fname):
|
def save_model(self, fname):
|
||||||
""" save model to file """
|
""" save model to file """
|
||||||
xglib.XGBoosterSaveModel( self.handle, ctypes.c_char_p(fname) )
|
xglib.XGBoosterSaveModel( self.handle, ctypes.c_char_p(fname) )
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user