Merge model compatibility fixes from 1.0rc branch. (#5305)
* Port test model compatibility. * Port logit model fix. https://github.com/dmlc/xgboost/pull/5248 https://github.com/dmlc/xgboost/pull/5281
This commit is contained in:
79
doc/python/convert_090to100.py
Normal file
79
doc/python/convert_090to100.py
Normal file
@@ -0,0 +1,79 @@
|
||||
'''This is a simple script that converts a pickled XGBoost
|
||||
Scikit-Learn interface object from 0.90 to a native model. Pickle
|
||||
format is not stable as it's a direct serialization of Python object.
|
||||
We advice not to use it when stability is needed.
|
||||
|
||||
'''
|
||||
import pickle
|
||||
import json
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import xgboost
|
||||
import warnings
|
||||
|
||||
|
||||
def save_label_encoder(le):
|
||||
'''Save the label encoder in XGBClassifier'''
|
||||
meta = dict()
|
||||
for k, v in le.__dict__.items():
|
||||
if isinstance(v, np.ndarray):
|
||||
meta[k] = v.tolist()
|
||||
else:
|
||||
meta[k] = v
|
||||
return meta
|
||||
|
||||
|
||||
def xgboost_skl_90to100(skl_model):
|
||||
'''Extract the model and related metadata in SKL model.'''
|
||||
model = {}
|
||||
with open(skl_model, 'rb') as fd:
|
||||
old = pickle.load(fd)
|
||||
if not isinstance(old, xgboost.XGBModel):
|
||||
raise TypeError(
|
||||
'The script only handes Scikit-Learn interface object')
|
||||
|
||||
# Save Scikit-Learn specific Python attributes into a JSON document.
|
||||
for k, v in old.__dict__.items():
|
||||
if k == '_le':
|
||||
model[k] = save_label_encoder(v)
|
||||
elif k == 'classes_':
|
||||
model[k] = v.tolist()
|
||||
elif k == '_Booster':
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
json.dumps({k: v})
|
||||
model[k] = v
|
||||
except TypeError:
|
||||
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
|
||||
booster = old.get_booster()
|
||||
# Store the JSON serialization as an attribute
|
||||
booster.set_attr(scikit_learn=json.dumps(model))
|
||||
|
||||
# Save it into a native model.
|
||||
i = 0
|
||||
while True:
|
||||
path = 'xgboost_native_model_from_' + skl_model + '-' + str(i) + '.bin'
|
||||
if os.path.exists(path):
|
||||
i += 1
|
||||
continue
|
||||
booster.save_model(path)
|
||||
break
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
assert xgboost.__version__ != '1.0.0', ('Please use the XGBoost version'
|
||||
' that generates this pickle.')
|
||||
parser = argparse.ArgumentParser(
|
||||
description=('A simple script to convert pickle generated by'
|
||||
' XGBoost 0.90 to XGBoost 1.0.0 model (not pickle).'))
|
||||
parser.add_argument(
|
||||
'--old-pickle',
|
||||
type=str,
|
||||
help='Path to old pickle file of Scikit-Learn interface object. '
|
||||
'Will output a native model converted from this pickle file',
|
||||
required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
xgboost_skl_90to100(args.old_pickle)
|
||||
Reference in New Issue
Block a user