add create from csc
This commit is contained in:
parent
2bc1d2e73a
commit
a1c6e22af9
@ -7,7 +7,6 @@
|
|||||||
#include "wrapper/xgboost_wrapper.h"
|
#include "wrapper/xgboost_wrapper.h"
|
||||||
#include "src/utils/utils.h"
|
#include "src/utils/utils.h"
|
||||||
#include "src/utils/omp.h"
|
#include "src/utils/omp.h"
|
||||||
#include "src/utils/matrix_csr.h"
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace xgboost;
|
using namespace xgboost;
|
||||||
|
|
||||||
@ -91,37 +90,25 @@ extern "C" {
|
|||||||
SEXP indices,
|
SEXP indices,
|
||||||
SEXP data) {
|
SEXP data) {
|
||||||
_WrapperBegin();
|
_WrapperBegin();
|
||||||
const int *col_ptr = INTEGER(indptr);
|
const int *p_indptr = INTEGER(indptr);
|
||||||
const int *row_index = INTEGER(indices);
|
const int *p_indices = INTEGER(indices);
|
||||||
const double *col_data = REAL(data);
|
const double *p_data = REAL(data);
|
||||||
int ncol = length(indptr) - 1;
|
int nindptr = length(indptr);
|
||||||
int ndata = length(data);
|
int ndata = length(data);
|
||||||
// transform into CSR format
|
std::vector<bst_ulong> col_ptr_(nindptr);
|
||||||
std::vector<bst_ulong> row_ptr;
|
std::vector<unsigned> indices_(ndata);
|
||||||
std::vector< std::pair<unsigned, float> > csr_data;
|
std::vector<float> data_(ndata);
|
||||||
utils::SparseCSRMBuilder<std::pair<unsigned,float>, false, bst_ulong> builder(row_ptr, csr_data);
|
|
||||||
builder.InitBudget();
|
for (int i = 0; i < nindptr; ++i) {
|
||||||
for (int i = 0; i < ncol; ++i) {
|
col_ptr_[i] = static_cast<bst_ulong>(p_indptr[i]);
|
||||||
for (int j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
|
|
||||||
builder.AddBudget(row_index[j]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
builder.InitStorage();
|
|
||||||
for (int i = 0; i < ncol; ++i) {
|
|
||||||
for (int j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
|
|
||||||
builder.PushElem(row_index[j], std::make_pair(i, col_data[j]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
utils::Assert(csr_data.size() == static_cast<size_t>(ndata), "BUG CreateFromCSC");
|
|
||||||
std::vector<float> row_data(ndata);
|
|
||||||
std::vector<unsigned> col_index(ndata);
|
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (int i = 0; i < ndata; ++i) {
|
for (int i = 0; i < ndata; ++i) {
|
||||||
col_index[i] = csr_data[i].first;
|
indices_[i] = static_cast<unsigned>(p_indices[i]);
|
||||||
row_data[i] = csr_data[i].second;
|
data_[i] = static_cast<float>(p_data[i]);
|
||||||
}
|
}
|
||||||
void *handle = XGDMatrixCreateFromCSR(BeginPtr(row_ptr), BeginPtr(col_index),
|
void *handle = XGDMatrixCreateFromCSC(BeginPtr(col_ptr_), BeginPtr(indices_),
|
||||||
BeginPtr(row_data), row_ptr.size(), ndata );
|
BeginPtr(data_), nindptr, ndata);
|
||||||
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||||
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
||||||
UNPROTECT(1);
|
UNPROTECT(1);
|
||||||
|
|||||||
@ -42,7 +42,7 @@ assert np.sum(np.abs(preds2-preds)) == 0
|
|||||||
|
|
||||||
###
|
###
|
||||||
# build dmatrix from scipy.sparse
|
# build dmatrix from scipy.sparse
|
||||||
print ('start running example of build DMatrix from scipy.sparse')
|
print ('start running example of build DMatrix from scipy.sparse CSR Matrix')
|
||||||
labels = []
|
labels = []
|
||||||
row = []; col = []; dat = []
|
row = []; col = []; dat = []
|
||||||
i = 0
|
i = 0
|
||||||
@ -54,8 +54,14 @@ for l in open('../data/agaricus.txt.train'):
|
|||||||
row.append(i); col.append(int(k)); dat.append(float(v))
|
row.append(i); col.append(int(k)); dat.append(float(v))
|
||||||
i += 1
|
i += 1
|
||||||
csr = scipy.sparse.csr_matrix( (dat, (row,col)) )
|
csr = scipy.sparse.csr_matrix( (dat, (row,col)) )
|
||||||
dtrain = xgb.DMatrix( csr )
|
dtrain = xgb.DMatrix( csr, label = labels )
|
||||||
dtrain.set_label(labels)
|
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||||
|
bst = xgb.train( param, dtrain, num_round, watchlist )
|
||||||
|
|
||||||
|
print ('start running example of build DMatrix from scipy.sparse CSC Matrix')
|
||||||
|
# we can also construct from csc matrix
|
||||||
|
csc = scipy.sparse.csc_matrix( (dat, (row,col)) )
|
||||||
|
dtrain = xgb.DMatrix(csc, label=labels)
|
||||||
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||||
bst = xgb.train( param, dtrain, num_round, watchlist )
|
bst = xgb.train( param, dtrain, num_round, watchlist )
|
||||||
|
|
||||||
@ -63,8 +69,7 @@ print ('start running example of build DMatrix from numpy array')
|
|||||||
# NOTE: npymat is numpy array, we will convert it into scipy.sparse.csr_matrix in internal implementation
|
# NOTE: npymat is numpy array, we will convert it into scipy.sparse.csr_matrix in internal implementation
|
||||||
# then convert to DMatrix
|
# then convert to DMatrix
|
||||||
npymat = csr.todense()
|
npymat = csr.todense()
|
||||||
dtrain = xgb.DMatrix( npymat)
|
dtrain = xgb.DMatrix(npymat, label = labels)
|
||||||
dtrain.set_label(labels)
|
|
||||||
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||||
bst = xgb.train( param, dtrain, num_round, watchlist )
|
bst = xgb.train( param, dtrain, num_round, watchlist )
|
||||||
|
|
||||||
|
|||||||
@ -22,6 +22,7 @@ xglib = ctypes.cdll.LoadLibrary(XGBOOST_PATH)
|
|||||||
# DMatrix functions
|
# DMatrix functions
|
||||||
xglib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
|
xglib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
|
||||||
xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
|
xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
|
||||||
|
xglib.XGDMatrixCreateFromCSC.restype = ctypes.c_void_p
|
||||||
xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
|
xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
|
||||||
xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
|
xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
|
||||||
xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
|
xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
|
||||||
@ -66,6 +67,8 @@ class DMatrix:
|
|||||||
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 0))
|
xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 0))
|
||||||
elif isinstance(data, scp.csr_matrix):
|
elif isinstance(data, scp.csr_matrix):
|
||||||
self.__init_from_csr(data)
|
self.__init_from_csr(data)
|
||||||
|
elif isinstance(data, scp.csc_matrix):
|
||||||
|
self.__init_from_csc(data)
|
||||||
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
|
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
|
||||||
self.__init_from_npy2d(data, missing)
|
self.__init_from_npy2d(data, missing)
|
||||||
else:
|
else:
|
||||||
@ -88,6 +91,15 @@ class DMatrix:
|
|||||||
(ctypes.c_float * len(csr.data))(*csr.data),
|
(ctypes.c_float * len(csr.data))(*csr.data),
|
||||||
len(csr.indptr), len(csr.data)))
|
len(csr.indptr), len(csr.data)))
|
||||||
|
|
||||||
|
def __init_from_csc(self, csc):
|
||||||
|
"""convert data from csr matrix"""
|
||||||
|
assert len(csc.indices) == len(csc.data)
|
||||||
|
self.handle = ctypes.c_void_p(xglib.XGDMatrixCreateFromCSC(
|
||||||
|
(ctypes.c_ulong * len(csc.indptr))(*csc.indptr),
|
||||||
|
(ctypes.c_uint * len(csc.indices))(*csc.indices),
|
||||||
|
(ctypes.c_float * len(csc.data))(*csc.data),
|
||||||
|
len(csc.indptr), len(csc.data)))
|
||||||
|
|
||||||
def __init_from_npy2d(self,mat,missing):
|
def __init_from_npy2d(self,mat,missing):
|
||||||
"""convert data from numpy matrix"""
|
"""convert data from numpy matrix"""
|
||||||
data = numpy.array(mat.reshape(mat.size), dtype='float32')
|
data = numpy.array(mat.reshape(mat.size), dtype='float32')
|
||||||
|
|||||||
@ -14,6 +14,7 @@ using namespace std;
|
|||||||
#include "../src/learner/learner-inl.hpp"
|
#include "../src/learner/learner-inl.hpp"
|
||||||
#include "../src/io/io.h"
|
#include "../src/io/io.h"
|
||||||
#include "../src/utils/utils.h"
|
#include "../src/utils/utils.h"
|
||||||
|
#include "../src/utils/matrix_csr.h"
|
||||||
#include "../src/io/simple_dmatrix-inl.hpp"
|
#include "../src/io/simple_dmatrix-inl.hpp"
|
||||||
|
|
||||||
using namespace xgboost;
|
using namespace xgboost;
|
||||||
@ -102,6 +103,31 @@ extern "C"{
|
|||||||
mat.info.info.num_row = nindptr - 1;
|
mat.info.info.num_row = nindptr - 1;
|
||||||
return p_mat;
|
return p_mat;
|
||||||
}
|
}
|
||||||
|
XGB_DLL void* XGDMatrixCreateFromCSC(const bst_ulong *col_ptr,
|
||||||
|
const unsigned *indices,
|
||||||
|
const float *data,
|
||||||
|
bst_ulong nindptr,
|
||||||
|
bst_ulong nelem) {
|
||||||
|
DMatrixSimple *p_mat = new DMatrixSimple();
|
||||||
|
DMatrixSimple &mat = *p_mat;
|
||||||
|
utils::SparseCSRMBuilder<RowBatch::Entry, false> builder(mat.row_ptr_, mat.row_data_);
|
||||||
|
builder.InitBudget();
|
||||||
|
bst_ulong ncol = nindptr - 1;
|
||||||
|
for (bst_ulong i = 0; i < ncol; ++i) {
|
||||||
|
for (unsigned j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
|
||||||
|
builder.AddBudget(indices[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder.InitStorage();
|
||||||
|
for (bst_ulong i = 0; i < ncol; ++i) {
|
||||||
|
for (unsigned j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
|
||||||
|
builder.PushElem(indices[j], RowBatch::Entry(static_cast<bst_uint>(i), data[j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mat.info.info.num_row = mat.row_ptr_.size() - 1;
|
||||||
|
mat.info.info.num_col = static_cast<size_t>(ncol);
|
||||||
|
return p_mat;
|
||||||
|
}
|
||||||
void* XGDMatrixCreateFromMat(const float *data,
|
void* XGDMatrixCreateFromMat(const float *data,
|
||||||
bst_ulong nrow,
|
bst_ulong nrow,
|
||||||
bst_ulong ncol,
|
bst_ulong ncol,
|
||||||
|
|||||||
@ -36,6 +36,20 @@ extern "C" {
|
|||||||
const float *data,
|
const float *data,
|
||||||
bst_ulong nindptr,
|
bst_ulong nindptr,
|
||||||
bst_ulong nelem);
|
bst_ulong nelem);
|
||||||
|
/*!
|
||||||
|
* \brief create a matrix content from CSC format
|
||||||
|
* \param col_ptr pointer to col headers
|
||||||
|
* \param indices findex
|
||||||
|
* \param data fvalue
|
||||||
|
* \param nindptr number of rows in the matix + 1
|
||||||
|
* \param nelem number of nonzero elements in the matrix
|
||||||
|
* \return created dmatrix
|
||||||
|
*/
|
||||||
|
XGB_DLL void* XGDMatrixCreateFromCSC(const bst_ulong *col_ptr,
|
||||||
|
const unsigned *indices,
|
||||||
|
const float *data,
|
||||||
|
bst_ulong nindptr,
|
||||||
|
bst_ulong nelem);
|
||||||
/*!
|
/*!
|
||||||
* \brief create matrix content from dense matrix
|
* \brief create matrix content from dense matrix
|
||||||
* \param data pointer to the data space
|
* \param data pointer to the data space
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user