diff --git a/python/xgboost.py b/python/xgboost.py index 922ca085d..d37566065 100644 --- a/python/xgboost.py +++ b/python/xgboost.py @@ -22,6 +22,13 @@ xglib.XGDMatrixGetLabel.restype = ctypes.POINTER( ctypes.c_float ) xglib.XGDMatrixGetRow.restype = ctypes.POINTER( REntry ) xglib.XGBoosterPredict.restype = ctypes.POINTER( ctypes.c_float ) +def ctypes2numpy( cptr, length ): + # convert a ctypes pointer array to numpy + assert isinstance( cptr, ctypes.POINTER( ctypes.c_float ) ) + res = numpy.zeros( length, dtype='float32' ) + assert ctypes.memmove( res.ctypes.data, cptr, length * res.strides[0] ) + return res + # data matrix used in xgboost class DMatrix: # constructor @@ -73,7 +80,7 @@ class DMatrix: def get_label(self): length = ctypes.c_ulong() labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length)) - return numpy.array( [labels[i] for i in xrange(length.value)] ) + return ctypes2numpy( labels, length.value ); # clear everything def clear(self): xglib.XGDMatrixClear(self.handle) @@ -138,7 +145,7 @@ class Booster: def predict(self, data, bst_group = -1): length = ctypes.c_ulong() preds = xglib.XGBoosterPredict( self.handle, data.handle, ctypes.byref(length), bst_group) - return numpy.array( [ preds[i] for i in xrange(length.value)]) + return ctypes2numpy( preds, length.value ) def save_model(self, fname): """ save model to file """ xglib.XGBoosterSaveModel( self.handle, ctypes.c_char_p(fname) )