allow booster to be pickable, add copy function
This commit is contained in:
parent
39f1da08d2
commit
e6b8b23a2c
@ -1,7 +1,9 @@
|
|||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
|
import pickle
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
|
import copy
|
||||||
|
|
||||||
### simple example
|
### simple example
|
||||||
# load file from text file, also binary buffer generated by xgboost
|
# load file from text file, also binary buffer generated by xgboost
|
||||||
@ -28,6 +30,7 @@ bst.dump_model('dump.nice.txt','../data/featmap.txt')
|
|||||||
|
|
||||||
# save dmatrix into binary buffer
|
# save dmatrix into binary buffer
|
||||||
dtest.save_binary('dtest.buffer')
|
dtest.save_binary('dtest.buffer')
|
||||||
|
# save model
|
||||||
bst.save_model('xgb.model')
|
bst.save_model('xgb.model')
|
||||||
# load model and data in
|
# load model and data in
|
||||||
bst2 = xgb.Booster(model_file='xgb.model')
|
bst2 = xgb.Booster(model_file='xgb.model')
|
||||||
@ -36,6 +39,14 @@ preds2 = bst2.predict(dtest2)
|
|||||||
# assert they are the same
|
# assert they are the same
|
||||||
assert np.sum(np.abs(preds2-preds)) == 0
|
assert np.sum(np.abs(preds2-preds)) == 0
|
||||||
|
|
||||||
|
# alternatively, you can pickle the booster
|
||||||
|
pks = pickle.dumps(bst2)
|
||||||
|
# load model and data in
|
||||||
|
bst3 = pickle.loads(pks)
|
||||||
|
preds3 = bst2.predict(dtest2)
|
||||||
|
# assert they are the same
|
||||||
|
assert np.sum(np.abs(preds3-preds)) == 0
|
||||||
|
|
||||||
###
|
###
|
||||||
# build dmatrix from scipy.sparse
|
# build dmatrix from scipy.sparse
|
||||||
print ('start running example of build DMatrix from scipy.sparse CSR Matrix')
|
print ('start running example of build DMatrix from scipy.sparse CSR Matrix')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user