This commit is contained in:
tqchen@graphlab.com 2014-08-18 10:14:34 -07:00
parent d3bfc31e6a
commit 7c068cbe46
3 changed files with 10 additions and 9 deletions

View File

@ -1,6 +1,6 @@
export CC = gcc
export CXX = g++
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fopenmp
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas
# specify tensor path
BIN = xgboost

View File

@ -44,7 +44,8 @@ class DMatrix:
self.handle = None
return
if isinstance(data, str):
self.handle = xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1)
self.handle = ctypes.c_void_p(
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1))
elif isinstance(data, scp.csr_matrix):
self.__init_from_csr(data)
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
@ -62,17 +63,17 @@ class DMatrix:
# convert data from csr matrix
def __init_from_csr(self, csr):
assert len(csr.indices) == len(csr.data)
self.handle = xglib.XGDMatrixCreateFromCSR(
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR(
(ctypes.c_ulong * len(csr.indptr))(*csr.indptr),
(ctypes.c_uint * len(csr.indices))(*csr.indices),
(ctypes.c_float * len(csr.data))(*csr.data),
len(csr.indptr), len(csr.data))
len(csr.indptr), len(csr.data)))
# convert data from numpy matrix
def __init_from_npy2d(self,mat,missing):
data = numpy.array(mat.reshape(mat.size), dtype='float32')
self.handle = xglib.XGDMatrixCreateFromMat(
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
mat.shape[0], mat.shape[1], ctypes.c_float(missing))
mat.shape[0], mat.shape[1], ctypes.c_float(missing)))
# destructor
def __del__(self):
xglib.XGDMatrixFree(self.handle)
@ -103,8 +104,8 @@ class DMatrix:
# slice the DMatrix to return a new DMatrix that only contains rindex
def slice(self, rindex):
res = DMatrix(None)
res.handle = xglib.XGDMatrixSliceDMatrix(
self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex))
res.handle = ctype.c_void_p(xglib.XGDMatrixSliceDMatrix(
self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex)))
return res
class Booster:

View File

@ -196,7 +196,7 @@ class DMatrixSimple : public DataMatrix {
/*! \brief data in the row */
std::vector<SparseBatch::Entry> row_data_;
/*! \brief magic number used to identify DMatrix */
static const int kMagic = 0xff01;
static const int kMagic = 0xffffab01;
protected:
// one batch iterator that return content in the matrix