python module pass basic test

This commit is contained in:
tqchen 2014-08-17 18:43:25 -07:00
parent af100dd869
commit 301685e0a4
7 changed files with 170 additions and 121 deletions

View File

@ -17,36 +17,17 @@ param = {'bst:max_depth':2, 'bst:eta':1, 'silent':1, 'objective':'binary:logisti
# specify validation sets to watch performance # specify validation sets to watch performance
evallist = [(dtest,'eval'), (dtrain,'train')] evallist = [(dtest,'eval'), (dtrain,'train')]
num_round = 2 num_round = 2
bst = xgb.train( param, dtrain, num_round, evallist ) bst = xgb.train(param, dtrain, num_round, evallist)
# this is prediction # this is prediction
preds = bst.predict( dtest ) preds = bst.predict(dtest)
labels = dtest.get_label() labels = dtest.get_label()
print ('error=%f' % ( sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) /float(len(preds)))) print ('error=%f' % ( sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) /float(len(preds))))
bst.save_model('0001.model') bst.save_model('0001.model')
# dump model # dump model
bst.dump_model('dump.raw.txt') bst.dump_model('dump.raw.txt')
# dump model with feature map # dump model with feature map
bst.dump_model('dump.raw.txt','featmap.txt') bst.dump_model('dump.nice.txt','featmap.txt')
###
# build dmatrix in python iteratively
#
print ('start running example of build DMatrix in python')
dtrain = xgb.DMatrix()
labels = []
for l in open('agaricus.txt.train'):
arr = l.split()
labels.append( int(arr[0]))
feats = []
for it in arr[1:]:
k,v = it.split(':')
feats.append( (int(k), float(v)) )
dtrain.add_row( feats )
dtrain.set_label( labels )
evallist = [(dtest,'eval'), (dtrain,'train')]
bst = xgb.train( param, dtrain, num_round, evallist )
### ###
# build dmatrix from scipy.sparse # build dmatrix from scipy.sparse
@ -61,7 +42,6 @@ for l in open('agaricus.txt.train'):
k,v = it.split(':') k,v = it.split(':')
row.append(i); col.append(int(k)); dat.append(float(v)) row.append(i); col.append(int(k)); dat.append(float(v))
i += 1 i += 1
csr = scipy.sparse.csr_matrix( (dat, (row,col)) ) csr = scipy.sparse.csr_matrix( (dat, (row,col)) )
dtrain = xgb.DMatrix( csr ) dtrain = xgb.DMatrix( csr )
dtrain.set_label(labels) dtrain.set_label(labels)
@ -71,7 +51,7 @@ bst = xgb.train( param, dtrain, num_round, evallist )
print ('start running example of build DMatrix from numpy array') print ('start running example of build DMatrix from numpy array')
# NOTE: npymat is a numpy array; we will convert it into scipy.sparse.csr_matrix in the internal implementation, then convert it to DMatrix # NOTE: npymat is a numpy array; we will convert it into scipy.sparse.csr_matrix in the internal implementation, then convert it to DMatrix
npymat = csr.todense() npymat = csr.todense()
dtrain = xgb.DMatrix( npymat ) dtrain = xgb.DMatrix( npymat)
dtrain.set_label(labels) dtrain.set_label(labels)
evallist = [(dtest,'eval'), (dtrain,'train')] evallist = [(dtest,'eval'), (dtrain,'train')]
bst = xgb.train( param, dtrain, num_round, evallist ) bst = xgb.train( param, dtrain, num_round, evallist )
@ -81,16 +61,25 @@ bst = xgb.train( param, dtrain, num_round, evallist )
# #
print ('start running example to used cutomized objective function') print ('start running example to used cutomized objective function')
# note: set objective= binary:logistic means the prediction will get logistic transformed # note: for customized objective function, we leave objective as default
# in most case, we may want to leave it as default # note: what we are getting is margin value in prediction
param = {'bst:max_depth':2, 'bst:eta':1, 'silent':1, 'objective':'binary:logistic' } # you must know what you are doing
param = {'bst:max_depth':2, 'bst:eta':1, 'silent':1 }
# user define objective function, given prediction, return gradient and second order gradient # user define objective function, given prediction, return gradient and second order gradient
def logregobj( preds, dtrain ): # this is loglikelihood loss
def logregobj(preds, dtrain):
labels = dtrain.get_label() labels = dtrain.get_label()
preds = 1.0 / (1.0 + np.exp(-preds))
grad = preds - labels grad = preds - labels
hess = preds * (1.0-preds) hess = preds * (1.0-preds)
return grad, hess return grad, hess
# training with customized objective, we can also do step by step training, simply look at xgboost.py's implementation of train # user defined evaluation function, return a pair metric_name, result
bst = xgb.train( param, dtrain, num_round, evallist, logregobj ) def evalerror(preds, dtrain):
labels = dtrain.get_label()
return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
# training with customized objective, we can also do step by step training
# simply look at xgboost.py's implementation of train
bst = xgb.train(param, dtrain, num_round, evallist, logregobj, evalerror)

View File

@ -4,6 +4,7 @@ import ctypes
import os import os
# optionally have scipy sparse, though not necessary # optionally have scipy sparse, though not necessary
import numpy import numpy
import sys
import numpy.ctypeslib import numpy.ctypeslib
import scipy.sparse as scp import scipy.sparse as scp
@ -13,33 +14,39 @@ XGBOOST_PATH = os.path.dirname(__file__)+'/libxgboostwrapper.so'
# load in xgboost library # load in xgboost library
xglib = ctypes.cdll.LoadLibrary(XGBOOST_PATH) xglib = ctypes.cdll.LoadLibrary(XGBOOST_PATH)
xglib.XGDMatrixCreate.restype = ctypes.c_void_p xglib.XGDMatrixCreateFromFile.restype = ctypes.c_void_p
xglib.XGDMatrixCreateFromCSR.restype = ctypes.c_void_p
xglib.XGDMatrixCreateFromMat.restype = ctypes.c_void_p
xglib.XGDMatrixSliceDMatrix.restype = ctypes.c_void_p
xglib.XGDMatrixGetLabel.restype = ctypes.POINTER(ctypes.c_float)
xglib.XGDMatrixGetWeight.restype = ctypes.POINTER(ctypes.c_float)
xglib.XGDMatrixNumRow.restype = ctypes.c_ulong xglib.XGDMatrixNumRow.restype = ctypes.c_ulong
xglib.XGDMatrixGetLabel.restype = ctypes.POINTER( ctypes.c_float )
xglib.XGDMatrixGetWeight.restype = ctypes.POINTER( ctypes.c_float )
xglib.XGDMatrixGetRow.restype = ctypes.POINTER( REntry )
xglib.XGBoosterCreate.restype = ctypes.c_void_p
xglib.XGBoosterPredict.restype = ctypes.POINTER( ctypes.c_float )
def ctypes2numpy( cptr, length ): xglib.XGBoosterCreate.restype = ctypes.c_void_p
xglib.XGBoosterPredict.restype = ctypes.POINTER(ctypes.c_float)
xglib.XGBoosterEvalOneIter.restype = ctypes.c_char_p
xglib.XGBoosterDumpModel.restype = ctypes.POINTER(ctypes.c_char_p)
def ctypes2numpy(cptr, length):
# convert a ctypes pointer array to numpy # convert a ctypes pointer array to numpy
assert isinstance( cptr, ctypes.POINTER( ctypes.c_float ) ) assert isinstance(cptr, ctypes.POINTER(ctypes.c_float))
res = numpy.zeros( length, dtype='float32' ) res = numpy.zeros(length, dtype='float32')
assert ctypes.memmove( res.ctypes.data, cptr, length * res.strides[0] ) assert ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0])
return res return res
# data matrix used in xgboost # data matrix used in xgboost
class DMatrix: class DMatrix:
# constructor # constructor
def __init__(self, data=None, label=None, missing=0.0, weight = None): def __init__(self, data, label=None, missing=0.0, weight = None):
# force into void_p, mac need to pass things in as void_p # force into void_p, mac need to pass things in as void_p
self.handle = ctypes.c_void_p( xglib.XGDMatrixCreate() )
if data == None: if data == None:
self.handle = None
return return
if isinstance(data,str): if isinstance(data, str):
xglib.XGDMatrixLoad(self.handle, ctypes.c_char_p(data.encode('utf-8')), 1) self.handle = xglib.XGDMatrixCreateFromFile(ctypes.c_char_p(data.encode('utf-8')), 1)
elif isinstance(data,scp.csr_matrix): elif isinstance(data, scp.csr_matrix):
self.__init_from_csr(data) self.__init_from_csr(data)
elif isinstance(data, numpy.ndarray) and len(data.shape) == 2: elif isinstance(data, numpy.ndarray) and len(data.shape) == 2:
self.__init_from_npy2d(data, missing) self.__init_from_npy2d(data, missing)
else: else:
@ -52,77 +59,68 @@ class DMatrix:
self.set_label(label) self.set_label(label)
if weight !=None: if weight !=None:
self.set_weight(weight) self.set_weight(weight)
# convert data from csr matrix # convert data from csr matrix
def __init_from_csr(self,csr): def __init_from_csr(self, csr):
assert len(csr.indices) == len(csr.data) assert len(csr.indices) == len(csr.data)
xglib.XGDMatrixParseCSR( self.handle, self.handle = xglib.XGDMatrixCreateFromCSR(
( ctypes.c_ulong * len(csr.indptr) )(*csr.indptr), (ctypes.c_ulong * len(csr.indptr))(*csr.indptr),
( ctypes.c_uint * len(csr.indices) )(*csr.indices), (ctypes.c_uint * len(csr.indices))(*csr.indices),
( ctypes.c_float * len(csr.data) )(*csr.data), (ctypes.c_float * len(csr.data))(*csr.data),
len(csr.indptr), len(csr.data) ) len(csr.indptr), len(csr.data))
# convert data from numpy matrix # convert data from numpy matrix
def __init_from_npy2d(self,mat,missing): def __init_from_npy2d(self,mat,missing):
data = numpy.array( mat.reshape(mat.size), dtype='float32' ) data = numpy.array(mat.reshape(mat.size), dtype='float32')
xglib.XGDMatrixParseMat( self.handle, self.handle = xglib.XGDMatrixCreateFromMat(
data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
mat.shape[0], mat.shape[1], ctypes.c_float(missing) ) mat.shape[0], mat.shape[1], ctypes.c_float(missing))
# destructor # destructor
def __del__(self): def __del__(self):
xglib.XGDMatrixFree(self.handle) xglib.XGDMatrixFree(self.handle)
# load data from file
def load(self, fname, silent=True):
xglib.XGDMatrixLoad(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent))
# save data to binary file # save data to binary file
def save_binary(self, fname, silent=True): def save_binary(self, fname, silent=True):
xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent)) xglib.XGDMatrixSaveBinary(self.handle, ctypes.c_char_p(fname.encode('utf-8')), int(silent))
# set label of dmatrix # set label of dmatrix
def set_label(self, label): def set_label(self, label):
xglib.XGDMatrixSetLabel(self.handle, (ctypes.c_float*len(label))(*label), len(label) ) xglib.XGDMatrixSetLabel(self.handle, (ctypes.c_float*len(label))(*label), len(label))
# set group size of dmatrix, used for rank # set group size of dmatrix, used for rank
def set_group(self, group): def set_group(self, group):
xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group) ) xglib.XGDMatrixSetGroup(self.handle, (ctypes.c_uint*len(group))(*group), len(group))
# set weight of each instance # set weight of each instance
def set_weight(self, weight): def set_weight(self, weight):
xglib.XGDMatrixSetWeight(self.handle, (ctypes.c_float*len(weight))(*weight), len(weight) ) xglib.XGDMatrixSetWeight(self.handle, (ctypes.c_float*len(weight))(*weight), len(weight))
# get label from dmatrix # get label from dmatrix
def get_label(self): def get_label(self):
length = ctypes.c_ulong() length = ctypes.c_ulong()
labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length)) labels = xglib.XGDMatrixGetLabel(self.handle, ctypes.byref(length))
return ctypes2numpy( labels, length.value ); return ctypes2numpy(labels, length.value)
# get weight from dmatrix # get weight from dmatrix
def get_weight(self): def get_weight(self):
length = ctypes.c_ulong() length = ctypes.c_ulong()
weights = xglib.XGDMatrixGetWeight(self.handle, ctypes.byref(length)) weights = xglib.XGDMatrixGetWeight(self.handle, ctypes.byref(length))
return ctypes2numpy( weights, length.value ); return ctypes2numpy(weights, length.value)
# clear everything
def clear(self):
xglib.XGDMatrixClear(self.handle)
def num_row(self): def num_row(self):
return xglib.XGDMatrixNumRow(self.handle) return xglib.XGDMatrixNumRow(self.handle)
# append a row to DMatrix # slice the DMatrix to return a new DMatrix that only contains rindex
def add_row(self, row): def slice(self, rindex):
xglib.XGDMatrixAddRow(self.handle, (REntry*len(row))(*row), len(row) ) res = DMatrix(None)
# get n-throw from DMatrix res.handle = xglib.XGDMatrixSliceDMatrix(
def __getitem__(self, ridx): self.handle, (ctypes.c_int*len(rindex))(*rindex), len(rindex))
length = ctypes.c_ulong() return res
row = xglib.XGDMatrixGetRow(self.handle, ridx, ctypes.byref(length) );
return [ (int(row[i].findex),row[i].fvalue) for i in range(length.value) ]
class Booster: class Booster:
"""learner class """ """learner class """
def __init__(self, params={}, cache=[]): def __init__(self, params={}, cache=[]):
""" constructor, param: """ """ constructor, param: """
for d in cache: for d in cache:
assert isinstance(d,DMatrix) assert isinstance(d, DMatrix)
dmats = ( ctypes.c_void_p * len(cache) )(*[ d.handle for d in cache]) dmats = (ctypes.c_void_p * len(cache))(*[ d.handle for d in cache])
self.handle = ctypes.c_void_p( xglib.XGBoosterCreate( dmats, len(cache) ) ) self.handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, len(cache)))
self.set_param( {'seed':0} ) self.set_param({'seed':0})
self.set_param( params ) self.set_param(params)
def __del__(self): def __del__(self):
xglib.XGBoosterFree(self.handle) xglib.XGBoosterFree(self.handle)
def set_param(self, params, pv=None): def set_param(self, params, pv=None):
if isinstance(params,dict): if isinstance(params, dict):
for k, v in params.items(): for k, v in params.items():
xglib.XGBoosterSetParam( xglib.XGBoosterSetParam(
self.handle, ctypes.c_char_p(k.encode('utf-8')), self.handle, ctypes.c_char_p(k.encode('utf-8')),
@ -130,72 +128,112 @@ class Booster:
elif isinstance(params,str) and pv != None: elif isinstance(params,str) and pv != None:
xglib.XGBoosterSetParam( xglib.XGBoosterSetParam(
self.handle, ctypes.c_char_p(params.encode('utf-8')), self.handle, ctypes.c_char_p(params.encode('utf-8')),
ctypes.c_char_p(str(pv).encode('utf-8')) ) ctypes.c_char_p(str(pv).encode('utf-8')))
else: else:
for k, v in params: for k, v in params:
xglib.XGBoosterSetParam( xglib.XGBoosterSetParam(
self.handle, ctypes.c_char_p(k.encode('utf-8')), self.handle, ctypes.c_char_p(k.encode('utf-8')),
ctypes.c_char_p(str(v).encode('utf-8')) ) ctypes.c_char_p(str(v).encode('utf-8')))
def update(self, dtrain): def update(self, dtrain, it):
""" update """ """ update """
assert isinstance(dtrain, DMatrix) assert isinstance(dtrain, DMatrix)
xglib.XGBoosterUpdateOneIter( self.handle, dtrain.handle ) xglib.XGBoosterUpdateOneIter(self.handle, it, dtrain.handle)
def boost(self, dtrain, grad, hess, bst_group = -1): def boost(self, dtrain, grad, hess):
""" update """ """ update """
assert len(grad) == len(hess) assert len(grad) == len(hess)
assert isinstance(dtrain, DMatrix) assert isinstance(dtrain, DMatrix)
xglib.XGBoosterBoostOneIter( self.handle, dtrain.handle, xglib.XGBoosterBoostOneIter(self.handle, dtrain.handle,
(ctypes.c_float*len(grad))(*grad), (ctypes.c_float*len(grad))(*grad),
(ctypes.c_float*len(hess))(*hess), (ctypes.c_float*len(hess))(*hess),
len(grad), bst_group ) len(grad))
def update_interact(self, dtrain, action, booster_index=None):
""" beta: update with specified action"""
assert isinstance(dtrain, DMatrix)
if booster_index != None:
self.set_param('interact:booster_index', str(booster_index))
xglib.XGBoosterUpdateInteract(
self.handle, dtrain.handle, ctypes.c_char_p(str(action)) )
def eval_set(self, evals, it = 0): def eval_set(self, evals, it = 0):
for d in evals: for d in evals:
assert isinstance(d[0], DMatrix) assert isinstance(d[0], DMatrix)
assert isinstance(d[1], str) assert isinstance(d[1], str)
dmats = ( ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals]) dmats = (ctypes.c_void_p * len(evals) )(*[ d[0].handle for d in evals])
evnames = ( ctypes.c_char_p * len(evals) )( evnames = (ctypes.c_char_p * len(evals))(
*[ctypes.c_char_p(d[1].encode('utf-8')) for d in evals]) * [ctypes.c_char_p(d[1].encode('utf-8')) for d in evals])
xglib.XGBoosterEvalOneIter( self.handle, it, dmats, evnames, len(evals) ) return xglib.XGBoosterEvalOneIter(self.handle, it, dmats, evnames, len(evals))
def eval(self, mat, name = 'eval', it = 0 ): def eval(self, mat, name = 'eval', it = 0):
self.eval_set( [(mat,name)], it) return self.eval_set( [(mat,name)], it)
def predict(self, data, bst_group = -1): def predict(self, data):
length = ctypes.c_ulong() length = ctypes.c_ulong()
preds = xglib.XGBoosterPredict( self.handle, data.handle, ctypes.byref(length), bst_group) preds = xglib.XGBoosterPredict(self.handle, data.handle, ctypes.byref(length))
return ctypes2numpy( preds, length.value ) return ctypes2numpy(preds, length.value)
def save_model(self, fname): def save_model(self, fname):
""" save model to file """ """ save model to file """
xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8'))) xglib.XGBoosterSaveModel(self.handle, ctypes.c_char_p(fname.encode('utf-8')))
def load_model(self, fname): def load_model(self, fname):
"""load model from file""" """load model from file"""
xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) ) xglib.XGBoosterLoadModel( self.handle, ctypes.c_char_p(fname.encode('utf-8')) )
def dump_model(self, fname, fmap=''): def dump_model(self, fo, fmap=''):
"""dump model into text file""" """dump model into text file"""
xglib.XGBoosterDumpModel( if isinstance(fo,str):
self.handle, ctypes.c_char_p(fname.encode('utf-8')), fo = open(fo,'w')
ctypes.c_char_p(fmap.encode('utf-8'))) need_close = True
else:
need_close = False
ret = self.get_dump(fmap)
for i in range(len(ret)):
fo.write('booster[%d]:\n' %i)
fo.write( ret[i] )
if need_close:
fo.close()
def get_dump(self, fmap=''):
"""get dump of model as list of strings """
length = ctypes.c_ulong()
sarr = xglib.XGBoosterDumpModel(self.handle, ctypes.c_char_p(fmap.encode('utf-8')), ctypes.byref(length))
res = []
for i in range(length.value):
res.append( str(sarr[i]) )
return res
def get_fscore(self, fmap=''):
""" get feature importance of each feature """
trees = self.get_dump(fmap)
fmap = {}
for tree in trees:
print tree
for l in tree.split('\n'):
arr = l.split('[')
if len(arr) == 1:
continue
fid = arr[1].split(']')[0]
fid = fid.split('<')[0]
if fid not in fmap:
fmap[fid] = 1
else:
fmap[fid]+= 1
return fmap
def train(params, dtrain, num_boost_round = 10, evals = [], obj=None): def evaluate(bst, evals, it, feval = None):
"""evaluation on eval set"""
if feval != None:
res = '[%d]' % it
for dm, evname in evals:
name, val = feval(bst.predict(dm), dm)
res += '\t%s-%s:%f' % (evname, name, val)
else:
res = bst.eval_set(evals, it)
return res
def train(params, dtrain, num_boost_round = 10, evals = [], obj=None, feval=None):
""" train a booster with given parameters """ """ train a booster with given parameters """
bst = Booster(params, [dtrain]+[ d[0] for d in evals ] ) bst = Booster(params, [dtrain]+[ d[0] for d in evals ] )
if obj == None: if obj == None:
for i in range(num_boost_round): for i in range(num_boost_round):
bst.update( dtrain ) bst.update( dtrain, i )
if len(evals) != 0: if len(evals) != 0:
bst.eval_set( evals, i ) sys.stderr.write(evaluate(bst, evals, i, feval)+'\n')
else: else:
if len(evals) != 0 and feval == None:
print 'you need to provide your own evaluation function'
# try customized objective function # try customized objective function
for i in range(num_boost_round): for i in range(num_boost_round):
pred = bst.predict( dtrain ) pred = bst.predict( dtrain )
grad, hess = obj( pred, dtrain ) grad, hess = obj( pred, dtrain )
bst.boost( dtrain, grad, hess ) bst.boost( dtrain, grad, hess )
if len(evals) != 0: if len(evals) != 0:
bst.eval_set( evals, i ) sys.stderr.write(evaluate(bst, evals, i, feval)+'\n')
return bst return bst

View File

@ -20,9 +20,11 @@ class Booster: public learner::BoostLearner<FMatrixS> {
public: public:
explicit Booster(const std::vector<DataMatrix*>& mats) { explicit Booster(const std::vector<DataMatrix*>& mats) {
this->silent = 1; this->silent = 1;
this->init_model = false;
this->SetCacheData(mats); this->SetCacheData(mats);
} }
const float *Pred(const DataMatrix &dmat, size_t *len) { const float *Pred(const DataMatrix &dmat, size_t *len) {
this->CheckInitModel();
this->Predict(dmat, &this->preds_); this->Predict(dmat, &this->preds_);
*len = this->preds_.size(); *len = this->preds_.size();
return &this->preds_[0]; return &this->preds_[0];
@ -37,6 +39,15 @@ class Booster: public learner::BoostLearner<FMatrixS> {
} }
gbm_->DoBoost(gpair_, train.fmat, train.info.root_index); gbm_->DoBoost(gpair_, train.fmat, train.info.root_index);
} }
inline void CheckInitModel(void) {
if (!init_model) {
this->InitModel(); init_model = true;
}
}
inline void LoadModel(const char *fname) {
learner::BoostLearner<FMatrixS>::LoadModel(fname);
this->init_model = true;
}
inline const char** GetModelDump(const utils::FeatMap& fmap, bool with_stats, size_t *len) { inline const char** GetModelDump(const utils::FeatMap& fmap, bool with_stats, size_t *len) {
model_dump = this->DumpModel(fmap, with_stats); model_dump = this->DumpModel(fmap, with_stats);
model_dump_cptr.resize(model_dump.size()); model_dump_cptr.resize(model_dump.size());
@ -52,6 +63,9 @@ class Booster: public learner::BoostLearner<FMatrixS> {
// temporary space to save model dump // temporary space to save model dump
std::vector<std::string> model_dump; std::vector<std::string> model_dump;
std::vector<const char*> model_dump_cptr; std::vector<const char*> model_dump_cptr;
private:
bool init_model;
}; };
} // namespace wrapper } // namespace wrapper
} // namespace xgboost } // namespace xgboost
@ -199,6 +213,7 @@ extern "C"{
void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain) { void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain) {
Booster *bst = static_cast<Booster*>(handle); Booster *bst = static_cast<Booster*>(handle);
DataMatrix *dtr = static_cast<DataMatrix*>(dtrain); DataMatrix *dtr = static_cast<DataMatrix*>(dtrain);
bst->CheckInitModel();
bst->CheckInit(dtr); bst->CheckInit(dtr);
bst->UpdateOneIter(iter, *dtr); bst->UpdateOneIter(iter, *dtr);
} }
@ -206,6 +221,7 @@ extern "C"{
float *grad, float *hess, size_t len) { float *grad, float *hess, size_t len) {
Booster *bst = static_cast<Booster*>(handle); Booster *bst = static_cast<Booster*>(handle);
DataMatrix *dtr = static_cast<DataMatrix*>(dtrain); DataMatrix *dtr = static_cast<DataMatrix*>(dtrain);
bst->CheckInitModel();
bst->CheckInit(dtr); bst->CheckInit(dtr);
bst->BoostOneIter(*dtr, grad, hess, len); bst->BoostOneIter(*dtr, grad, hess, len);
} }
@ -217,6 +233,7 @@ extern "C"{
mats.push_back(static_cast<DataMatrix*>(dmats[i])); mats.push_back(static_cast<DataMatrix*>(dmats[i]));
names.push_back(std::string(evnames[i])); names.push_back(std::string(evnames[i]));
} }
bst->CheckInitModel();
bst->eval_str = bst->EvalOneIter(iter, mats, names); bst->eval_str = bst->EvalOneIter(iter, mats, names);
return bst->eval_str.c_str(); return bst->eval_str.c_str();
} }

View File

@ -242,7 +242,7 @@ class FMatrixS : public FMatrixInterface<FMatrixS>{
* \brief save column access data into stream * \brief save column access data into stream
* \param fo output stream to save to * \param fo output stream to save to
*/ */
inline void SaveColAccess(utils::IStream &fo) { inline void SaveColAccess(utils::IStream &fo) const {
fo.Write(&num_buffered_row_, sizeof(num_buffered_row_)); fo.Write(&num_buffered_row_, sizeof(num_buffered_row_));
if (num_buffered_row_ != 0) { if (num_buffered_row_ != 0) {
SaveBinary(fo, col_ptr_, col_data_); SaveBinary(fo, col_ptr_, col_data_);

View File

@ -15,7 +15,12 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) {
} }
void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) { void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) {
utils::Error("not implemented"); if (dmat.magic == DMatrixSimple::kMagic){
const DMatrixSimple *p_dmat = static_cast<const DMatrixSimple*>(&dmat);
p_dmat->SaveBinary(fname, silent);
} else {
utils::Error("not implemented");
}
} }
} // namespace io } // namespace io

View File

@ -148,7 +148,7 @@ class DMatrixSimple : public DataMatrix {
* \param fname name of binary data * \param fname name of binary data
* \param silent whether print information or not * \param silent whether print information or not
*/ */
inline void SaveBinary(const char* fname, bool silent = false) { inline void SaveBinary(const char* fname, bool silent = false) const {
utils::FileStream fs(utils::FopenCheck(fname, "wb")); utils::FileStream fs(utils::FopenCheck(fname, "wb"));
int magic = kMagic; int magic = kMagic;
fs.Write(&magic, sizeof(magic)); fs.Write(&magic, sizeof(magic));

View File

@ -58,7 +58,7 @@ struct MetaInfo {
return 0; return 0;
} }
} }
inline void SaveBinary(utils::IStream &fo) { inline void SaveBinary(utils::IStream &fo) const {
fo.Write(&num_row, sizeof(num_row)); fo.Write(&num_row, sizeof(num_row));
fo.Write(&num_col, sizeof(num_col)); fo.Write(&num_col, sizeof(num_col));
fo.Write(labels); fo.Write(labels);