add interact mode
This commit is contained in:
parent
6fd77cbb24
commit
2057dda560
@ -11,7 +11,7 @@ dtrain = xgb.DMatrix('agaricus.txt.train')
|
|||||||
dtest = xgb.DMatrix('agaricus.txt.test')
|
dtest = xgb.DMatrix('agaricus.txt.test')
|
||||||
|
|
||||||
# specify parameters via map, definition are same as c++ version
|
# specify parameters via map, definition are same as c++ version
|
||||||
param = {'bst:max_depth':4, 'bst:eta':1, 'silent':1, 'loss_type':2 }
|
param = {'bst:max_depth':2, 'bst:eta':1, 'silent':1, 'loss_type':2 }
|
||||||
|
|
||||||
# specify validations set to watch performance
|
# specify validations set to watch performance
|
||||||
evallist = [(dtest,'eval'), (dtrain,'train')]
|
evallist = [(dtest,'eval'), (dtrain,'train')]
|
||||||
@ -23,7 +23,15 @@ preds = bst.predict( dtest )
|
|||||||
labels = dtest.get_label()
|
labels = dtest.get_label()
|
||||||
print 'error=%f' % ( sum(1 for i in xrange(len(preds)) if int(preds[i]>0.5)!=labels[i]) /float(len(preds)))
|
print 'error=%f' % ( sum(1 for i in xrange(len(preds)) if int(preds[i]>0.5)!=labels[i]) /float(len(preds)))
|
||||||
bst.save_model('0001.model')
|
bst.save_model('0001.model')
|
||||||
|
# dump model
|
||||||
|
bst.dump_model('dump.raw.txt')
|
||||||
|
# dump model with feature map
|
||||||
|
bst.dump_model('dump.raw.txt','featmap.txt')
|
||||||
|
|
||||||
|
# beta: interact mode
|
||||||
|
bst.set_param('bst:interact:expand',4)
|
||||||
|
bst.update_interact( dtrain, 'update', 0)
|
||||||
|
bst.dump_model('dump.raw2.txt')
|
||||||
|
|
||||||
###
|
###
|
||||||
# build dmatrix in python iteratively
|
# build dmatrix in python iteratively
|
||||||
|
|||||||
@ -95,12 +95,26 @@ class Booster:
|
|||||||
assert isinstance(d,DMatrix)
|
assert isinstance(d,DMatrix)
|
||||||
dmats = ( ctypes.c_void_p * len(cache) )(*[ ctypes.c_void_p(d.handle) for d in cache])
|
dmats = ( ctypes.c_void_p * len(cache) )(*[ ctypes.c_void_p(d.handle) for d in cache])
|
||||||
self.handle = xglib.XGBoosterCreate( dmats, len(cache) )
|
self.handle = xglib.XGBoosterCreate( dmats, len(cache) )
|
||||||
|
self.set_param( params )
|
||||||
|
def set_param(self, params,pv=None):
|
||||||
|
if isinstance(params,dict):
|
||||||
for k, v in params.iteritems():
|
for k, v in params.iteritems():
|
||||||
xglib.XGBoosterSetParam( self.handle, ctypes.c_char_p(k), ctypes.c_char_p(str(v)) )
|
xglib.XGBoosterSetParam( self.handle, ctypes.c_char_p(k), ctypes.c_char_p(str(v)) )
|
||||||
|
elif isinstance(params,str) and pv != None:
|
||||||
|
xglib.XGBoosterSetParam( self.handle, ctypes.c_char_p(params), ctypes.c_char_p(str(pv)) )
|
||||||
|
else:
|
||||||
|
for k, v in params:
|
||||||
|
xglib.XGBoosterSetParam( self.handle, ctypes.c_char_p(k), ctypes.c_char_p(str(v)) )
|
||||||
def update(self, dtrain):
|
def update(self, dtrain):
|
||||||
""" update """
|
""" update """
|
||||||
assert isinstance(dtrain, DMatrix)
|
assert isinstance(dtrain, DMatrix)
|
||||||
xglib.XGBoosterUpdateOneIter( self.handle, dtrain.handle )
|
xglib.XGBoosterUpdateOneIter( self.handle, dtrain.handle )
|
||||||
|
def update_interact(self, dtrain, action, booster_index=None):
|
||||||
|
""" beta: update with specified action"""
|
||||||
|
assert isinstance(dtrain, DMatrix)
|
||||||
|
if booster_index != None:
|
||||||
|
self.set_param('interact:booster_index', str(booster_index))
|
||||||
|
xglib.XGBoosterUpdateInteract( self.handle, dtrain.handle, ctypes.c_char_p(str(action)) )
|
||||||
def eval_set(self, evals, it = 0):
|
def eval_set(self, evals, it = 0):
|
||||||
for d in evals:
|
for d in evals:
|
||||||
assert isinstance(d[0], DMatrix)
|
assert isinstance(d[0], DMatrix)
|
||||||
|
|||||||
@ -210,5 +210,13 @@ extern "C"{
|
|||||||
static_cast<Booster*>(handle)->DumpModel( fo, featmap, false );
|
static_cast<Booster*>(handle)->DumpModel( fo, featmap, false );
|
||||||
fclose( fo );
|
fclose( fo );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void XGBoosterUpdateInteract( void *handle, void *dtrain, const char *action ){
|
||||||
|
Booster *bst = static_cast<Booster*>(handle);
|
||||||
|
DMatrix *dtr = static_cast<DMatrix*>(dtrain);
|
||||||
|
bst->CheckInit(); dtr->CheckInit();
|
||||||
|
std::string act( action );
|
||||||
|
bst->UpdateInteract( act, *dtr );
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -157,6 +157,13 @@ extern "C"{
|
|||||||
* \param fmap name to fmap can be empty string
|
* \param fmap name to fmap can be empty string
|
||||||
*/
|
*/
|
||||||
void XGBoosterDumpModel( void *handle, const char *fname, const char *fmap );
|
void XGBoosterDumpModel( void *handle, const char *fname, const char *fmap );
|
||||||
|
/*!
|
||||||
|
* \brief interactively update model: beta
|
||||||
|
* \param handle handle
|
||||||
|
* \param dtrain training data
|
||||||
|
* \param action action name
|
||||||
|
*/
|
||||||
|
void XGBoosterUpdateInteract( void *handle, void *dtrain, const char* action );
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user