This commit is contained in:
tqchen@graphlab.com 2014-08-18 10:14:34 -07:00
parent d3bfc31e6a
commit 7c068cbe46
3 changed files with 10 additions and 9 deletions

View File

@ -1,6 +1,6 @@
export CC = gcc export CC = gcc
export CXX = g++ export CXX = g++
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -fopenmp export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas
# specify tensor path # specify tensor path
BIN = xgboost BIN = xgboost

View File

@ -44,7 +44,8 @@ class DMatrix:
self.handle = None self.handle = None
return return
if isinstance(data, str): if isinstance(data, str):
self.handle = xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1) self.handle = ctypes.c_void_p(
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1))
elif isinstance(data, scp.csr_matrix): elif isinstance(data, scp.csr_matrix):
self.__init_from_csr(data) self.__init_from_csr(data)
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2: elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
@ -62,17 +63,17 @@ class DMatrix:
# convert data from csr matrix # convert data from csr matrix
def __init_from_csr(self, csr): def __init_from_csr(self, csr):
assert len(csr.indices) == len(csr.data) assert len(csr.indices) == len(csr.data)
self.handle = xglib.XGDMatrixCreateFromCSR( self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSR(
(ctypes.c_ulong * len(csr.indptr))(*csr.indptr), (ctypes.c_ulong * len(csr.indptr))(*csr.indptr),
(ctypes.c_uint * len(csr.indices))(*csr.indices), (ctypes.c_uint * len(csr.indices))(*csr.indices),
(ctypes.c_float * len(csr.data))(*csr.data), (ctypes.c_float * len(csr.data))(*csr.data),
len(csr.indptr), len(csr.data)) len(csr.indptr), len(csr.data)))
# convert data from numpy matrix # convert data from numpy matrix
def __init_from_npy2d(self,mat,missing): def __init_from_npy2d(self,mat,missing):
data = numpy.array(mat.reshape(mat.size), dtype='float32') data = numpy.array(mat.reshape(mat.size), dtype='float32')
self.handle = xglib.XGDMatrixCreateFromMat( self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromMat(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
mat.shape[0], mat.shape[1], ctypes.c_float(missing)) mat.shape[0], mat.shape[1], ctypes.c_float(missing)))
# destructor # destructor
def __del__(self): def __del__(self):
xglib.XGDMatrixFree(self.handle) xglib.XGDMatrixFree(self.handle)
@ -103,8 +104,8 @@ class DMatrix:
# slice the DMatrix to return a new DMatrix that only contains rindex # slice the DMatrix to return a new DMatrix that only contains rindex
def slice(self, rindex): def slice(self, rindex):
res = DMatrix(None) res = DMatrix(None)
res.handle = xglib.XGDMatrixSliceDMatrix( res.handle = ctype.c_void_p(xglib.XGDMatrixSliceDMatrix(
self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex)) self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex)))
return res return res
class Booster: class Booster:

View File

@ -196,7 +196,7 @@ class DMatrixSimple : public DataMatrix {
/*! \brief data in the row */ /*! \brief data in the row */
std::vector<SparseBatch::Entry> row_data_; std::vector<SparseBatch::Entry> row_data_;
/*! \brief magic number used to identify DMatrix */ /*! \brief magic number used to identify DMatrix */
static const int kMagic = 0xff01; static const int kMagic = 0xffffab01;
protected: protected:
// one batch iterator that return content in the matrix // one batch iterator that return content in the matrix