change data format to include weight in binary file, add get weight to python
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
python wrapper for xgboost using ctypes
|
||||
|
||||
see example for usage
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
# append the path to xgboost
|
||||
# append the path to xgboost, you may need to change the following line
|
||||
sys.path.append('../')
|
||||
import xgboost as xgb
|
||||
|
||||
@@ -82,7 +82,7 @@ evallist = [(dtest,'eval'), (dtrain,'train')]
|
||||
bst = xgb.train( param, dtrain, num_round, evallist )
|
||||
|
||||
###
|
||||
# cutomsized loss function, set loss_type to 0, so that predict get untransformed score
|
||||
# advanced: cutomsized loss function, set loss_type to 0, so that predict get untransformed score
|
||||
#
|
||||
print 'start running example to used cutomized objective function'
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ xglib = ctypes.cdll.LoadLibrary(XGBOOST_PATH)
|
||||
xglib.XGDMatrixCreate.restype = ctypes.c_void_p
|
||||
xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
|
||||
xglib.XGDMatrixGetLabel.restype = ctypes.POINTER( ctypes.c_float )
|
||||
xglib.XGDMatrixGetWeight.restype = ctypes.POINTER( ctypes.c_float )
|
||||
xglib.XGDMatrixGetRow.restype = ctypes.POINTER( REntry )
|
||||
xglib.XGBoosterPredict.restype = ctypes.POINTER( ctypes.c_float )
|
||||
|
||||
@@ -81,6 +82,11 @@ class DMatrix:
|
||||
length = ctypes.c_ulong()
|
||||
labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length))
|
||||
return ctypes2numpy( labels, length.value );
|
||||
# get weight from dmatrix
|
||||
def get_weight(self):
|
||||
length = ctypes.c_ulong()
|
||||
weights = xglib.XGDMatrixGetWeight(self.handle, ctypes.byref(length))
|
||||
return ctypes2numpy( weights, length.value );
|
||||
# clear everything
|
||||
def clear(self):
|
||||
xglib.XGDMatrixClear(self.handle)
|
||||
|
||||
@@ -72,6 +72,10 @@ namespace xgboost{
|
||||
*len = this->info.labels.size();
|
||||
return &(this->info.labels[0]);
|
||||
}
|
||||
inline const float* GetWeight( size_t* len ) const{
|
||||
*len = this->info.weights.size();
|
||||
return &(this->info.weights[0]);
|
||||
}
|
||||
inline void CheckInit(void){
|
||||
if(!init_col_){
|
||||
this->data.InitData();
|
||||
@@ -171,6 +175,9 @@ extern "C"{
|
||||
const float* XGDMatrixGetLabel( const void *handle, size_t* len ){
|
||||
return static_cast<const DMatrix*>(handle)->GetLabel(len);
|
||||
}
|
||||
const float* XGDMatrixGetWeight( const void *handle, size_t* len ){
|
||||
return static_cast<const DMatrix*>(handle)->GetWeight(len);
|
||||
}
|
||||
void XGDMatrixClear(void *handle){
|
||||
static_cast<DMatrix*>(handle)->Clear();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user