move ncol, row to booster, add set/get uint info

This commit is contained in:
tqchen 2014-08-24 17:19:22 -07:00
parent 19447cdb12
commit da75f8f1a4
3 changed files with 61 additions and 24 deletions

View File

@ -19,6 +19,7 @@ xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float) xglib.XGDMatrixGetFloatInfo.restype = ctypes.POINTER(ctypes.c_float)
xglib.XGDMatrixGetUIntInfo.restype = ctypes.POINTER(ctypes.c_uint)
xglib.XGDMatrixNumRow.restype = ctypes.c_ulong xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
xglib.XGBoosterCreate.restype = ctypes.c_void_p xglib.XGBoosterCreate.restype = ctypes.c_void_p
@ -27,10 +28,10 @@ xglib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p) xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
def ctypes2numpy(cptr, length): def ctypes2numpy(cptr, length, dtype):
# convert a ctypes pointer array to numpy # convert a ctypes pointer array to numpy
assert isinstance(cptr, ctypes.POINTER(ctypes.c_float)) assert isinstance(cptr, ctypes.POINTER(ctypes.c_float))
res = numpy.zeros(length, dtype='float32') res = numpy.zeros(length, dtype=dtype)
assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]) assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0])
return res return res
@ -76,23 +77,31 @@ class DMatrix:
# destructor # destructor
def __del__(self): def __del__(self):
xglib.XGDMatrixFree(self.handle) xglib.XGDMatrixFree(self.handle)
def __get_float_info(self, field): def get_float_info(self, field):
length = ctypes.c_ulong() length = ctypes.c_ulong()
ret = xglib.XGDMatrixGetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), ret = xglib.XGDMatrixGetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
ctypes.byref(length)) ctypes.byref(length))
return ctypes2numpy(ret, length.value) return ctypes2numpy(ret, length.value, 'float32')
def __set_float_info(self, field, data): def get_uint_info(self, field):
length = ctypes.c_ulong()
ret = xglib.XGDMatrixGetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
ctypes.byref(length))
return ctypes2numpy(ret, length.value, 'uint32')
def set_float_info(self, field, data):
xglib.XGDMatrixSetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')), xglib.XGDMatrixSetFloatInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
(ctypes.c_float*len(data))(*data), len(data)) (ctypes.c_float*len(data))(*data), len(data))
def set_uint_info(self, field, data):
xglib.XGDMatrixSetUIntInfo(self.handle, ctypes.c_char_p(field.encode('utf-8')),
(ctypes.c_uint*len(data))(*data), len(data))
# load data from file # load data from file
def save_binary(self, fname, silent=True): def save_binary(self, fname, silent=True):
xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent)) xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent))
# set label of dmatrix # set label of dmatrix
def set_label(self, label): def set_label(self, label):
self.__set_float_info('label', label) self.set_float_info('label', label)
# set weight of each instances # set weight of each instances
def set_weight(self, weight): def set_weight(self, weight):
self.__set_float_info('weight', weight) self.set_float_info('weight', weight)
# set initialized margin prediction # set initialized margin prediction
def set_base_margin(self, margin): def set_base_margin(self, margin):
""" """
@ -103,19 +112,19 @@ class DMatrix:
e.g. for logistic regression: need to put in value before logistic transformation e.g. for logistic regression: need to put in value before logistic transformation
see also example/demo.py see also example/demo.py
""" """
self.__set_float_info('base_margin', margin) self.set_float_info('base_margin', margin)
# set group size of dmatrix, used for rank # set group size of dmatrix, used for rank
def set_group(self, group): def set_group(self, group):
xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group)) xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group))
# get label from dmatrix # get label from dmatrix
def get_label(self): def get_label(self):
return self.__get_float_info('label') return self.get_float_info('label')
# get weight from dmatrix # get weight from dmatrix
def get_weight(self): def get_weight(self):
return self.__get_float_info('weight') return self.get_float_info('weight')
# get base_margin from dmatrix # get base_margin from dmatrix
def get_base_margin(self): def get_base_margin(self):
return self.__get_float_info('base_margin') return self.get_float_info('base_margin')
def num_row(self): def num_row(self):
return xglib.XGDMatrixNumRow(self.handle) return xglib.XGDMatrixNumRow(self.handle)
# slice the DMatrix to return a new DMatrix that only contains rindex # slice the DMatrix to return a new DMatrix that only contains rindex
@ -189,7 +198,7 @@ class Booster:
length = ctypes.c_ulong() length = ctypes.c_ulong()
preds = xglib.XGBoosterPredict(self.handle, data.handle, preds = xglib.XGBoosterPredict(self.handle, data.handle,
int(output_margin), ctypes.byref(length)) int(output_margin), ctypes.byref(length))
return ctypes2numpy(preds, length.value) return ctypes2numpy(preds, length.value, 'float32')
def save_model(self, fname): def save_model(self, fname):
""" save model to file """ """ save model to file """
xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8'))) xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8')))

View File

@ -88,10 +88,10 @@ extern "C"{
mat.row_data_.resize(nelem); mat.row_data_.resize(nelem);
for (size_t i = 0; i < nelem; ++i) { for (size_t i = 0; i < nelem; ++i) {
mat.row_data_[i] = SparseBatch::Entry(indices[i], data[i]); mat.row_data_[i] = SparseBatch::Entry(indices[i], data[i]);
mat.info.num_col = std::max(mat.info.num_col, mat.info.info.num_col = std::max(mat.info.info.num_col,
static_cast<size_t>(indices[i]+1)); static_cast<size_t>(indices[i]+1));
} }
mat.info.num_row = nindptr - 1; mat.info.info.num_row = nindptr - 1;
return p_mat; return p_mat;
} }
void* XGDMatrixCreateFromMat(const float *data, void* XGDMatrixCreateFromMat(const float *data,
@ -100,8 +100,8 @@ extern "C"{
float missing) { float missing) {
DMatrixSimple *p_mat = new DMatrixSimple(); DMatrixSimple *p_mat = new DMatrixSimple();
DMatrixSimple &mat = *p_mat; DMatrixSimple &mat = *p_mat;
mat.info.num_row = nrow; mat.info.info.num_row = nrow;
mat.info.num_col = ncol; mat.info.info.num_col = ncol;
for (size_t i = 0; i < nrow; ++i, data += ncol) { for (size_t i = 0; i < nrow; ++i, data += ncol) {
size_t nelem = 0; size_t nelem = 0;
for (size_t j = 0; j < ncol; ++j) { for (size_t j = 0; j < ncol; ++j) {
@ -130,8 +130,8 @@ extern "C"{
utils::Check(src.info.group_ptr.size() == 0, utils::Check(src.info.group_ptr.size() == 0,
"slice does not support group structure"); "slice does not support group structure");
ret.Clear(); ret.Clear();
ret.info.num_row = len; ret.info.info.num_row = len;
ret.info.num_col = src.info.num_col; ret.info.info.num_col = src.info.num_col();
utils::IIterator<SparseBatch> *iter = src.fmat.RowIterator(); utils::IIterator<SparseBatch> *iter = src.fmat.RowIterator();
iter->BeforeFirst(); iter->BeforeFirst();
@ -165,10 +165,16 @@ extern "C"{
} }
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *info, size_t len) { void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *info, size_t len) {
std::vector<float> &vec = std::vector<float> &vec =
static_cast<DataMatrix*>(handle)->info.GetInfo(field); static_cast<DataMatrix*>(handle)->info.GetFloatInfo(field);
vec.resize(len); vec.resize(len);
memcpy(&vec[0], info, sizeof(float) * len); memcpy(&vec[0], info, sizeof(float) * len);
} }
void XGDMatrixSetUIntInfo(void *handle, const char *field, const unsigned *info, size_t len) {
std::vector<unsigned> &vec =
static_cast<DataMatrix*>(handle)->info.GetUIntInfo(field);
vec.resize(len);
memcpy(&vec[0], info, sizeof(unsigned) * len);
}
void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len) { void XGDMatrixSetGroup(void *handle, const unsigned *group, size_t len) {
DataMatrix *pmat = static_cast<DataMatrix*>(handle); DataMatrix *pmat = static_cast<DataMatrix*>(handle);
pmat->info.group_ptr.resize(len + 1); pmat->info.group_ptr.resize(len + 1);
@ -179,12 +185,18 @@ extern "C"{
} }
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* len) { const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* len) {
const std::vector<float> &vec = const std::vector<float> &vec =
static_cast<const DataMatrix*>(handle)->info.GetInfo(field); static_cast<const DataMatrix*>(handle)->info.GetFloatInfo(field);
*len = vec.size();
return &vec[0];
}
const unsigned* XGDMatrixGetUIntInfo(const void *handle, const char *field, size_t* len) {
const std::vector<unsigned> &vec =
static_cast<const DataMatrix*>(handle)->info.GetUIntInfo(field);
*len = vec.size(); *len = vec.size();
return &vec[0]; return &vec[0];
} }
size_t XGDMatrixNumRow(const void *handle) { size_t XGDMatrixNumRow(const void *handle) {
return static_cast<const DataMatrix*>(handle)->info.num_row; return static_cast<const DataMatrix*>(handle)->info.num_row();
} }
// xgboost implementation // xgboost implementation

View File

@ -69,6 +69,14 @@ extern "C" {
* \param len length of array * \param len length of array
*/ */
void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *array, size_t len); void XGDMatrixSetFloatInfo(void *handle, const char *field, const float *array, size_t len);
/*!
* \brief set uint32 vector to a content in info
* \param handle a instance of data matrix
* \param field field name
* \param array pointer to float vector
* \param len length of array
*/
void XGDMatrixSetUIntInfo(void *handle, const char *field, const unsigned *array, size_t len);
/*! /*!
* \brief set label of the training matrix * \brief set label of the training matrix
* \param handle a instance of data matrix * \param handle a instance of data matrix
@ -81,9 +89,17 @@ extern "C" {
* \param handle a instance of data matrix * \param handle a instance of data matrix
* \param field field name * \param field field name
* \param out_len used to set result length * \param out_len used to set result length
* \return pointer to the label * \return pointer to the result
*/ */
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* out_len); const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* out_len);
/*!
* \brief get uint32 info vector from matrix
* \param handle a instance of data matrix
* \param field field name
* \param out_len used to set result length
* \return pointer to the result
*/
const unsigned* XGDMatrixGetUIntInfo(const void *handle, const char *field, size_t* out_len);
/*! /*!
* \brief return number of rows * \brief return number of rows
*/ */