* Port test model compatibility. * Port logit model fix. https://github.com/dmlc/xgboost/pull/5248 https://github.com/dmlc/xgboost/pull/5281
80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
'''This is a simple script that converts a pickled XGBoost
|
|
Scikit-Learn interface object from 0.90 to a native model. Pickle
|
|
format is not stable as it's a direct serialization of Python object.
|
|
We advise against using it when stability is needed.
|
|
|
|
'''
|
|
import pickle
|
|
import json
|
|
import os
|
|
import argparse
|
|
import numpy as np
|
|
import xgboost
|
|
import warnings
|
|
|
|
|
|
def save_label_encoder(le):
    '''Save the label encoder in XGBClassifier.

    Copies every entry of ``le.__dict__`` into a plain dictionary that is
    meant to be embedded in a JSON document.  NumPy arrays are converted to
    (nested) lists and NumPy scalars to the equivalent built-in Python
    scalar; every other value is stored unchanged.

    Parameters
    ----------
    le : object
        The label encoder.  Only its instance ``__dict__`` is read.

    Returns
    -------
    dict
        Mapping of attribute name to the converted attribute value.
    '''
    meta = {}
    for k, v in le.__dict__.items():
        if isinstance(v, np.ndarray):
            meta[k] = v.tolist()
        elif isinstance(v, np.generic):
            # NumPy scalars (np.int64, np.float32, ...) are rejected by the
            # json module, so unwrap them into built-in Python scalars.
            meta[k] = v.item()
        else:
            meta[k] = v
    return meta
|
|
|
|
|
|
def xgboost_skl_90to100(skl_model):
    '''Extract the model and related metadata in SKL model.

    Loads a pickled Scikit-Learn interface object (``xgboost.XGBModel``)
    from ``skl_model``, stores its JSON-serializable Python attributes as
    the ``scikit_learn`` attribute of the underlying booster, and saves the
    booster as a native model file.

    The output file is named
    ``xgboost_native_model_from_<pickle file name>-<i>.bin`` and is written
    into the same directory as the input pickle, where ``<i>`` is the
    smallest non-negative integer for which no such file exists yet.

    Parameters
    ----------
    skl_model : str
        Path to the pickle file of the Scikit-Learn interface object.

    Raises
    ------
    TypeError
        If the unpickled object is not an ``xgboost.XGBModel``.
    '''
    # NOTE: unpickling executes arbitrary code embedded in the file; only
    # run this script on pickles that come from a trusted source.
    with open(skl_model, 'rb') as fd:
        old = pickle.load(fd)
    if not isinstance(old, xgboost.XGBModel):
        raise TypeError(
            'The script only handles Scikit-Learn interface object')

    # Save Scikit-Learn specific Python attributes into a JSON document.
    model = {}
    for k, v in old.__dict__.items():
        if k == '_Booster':
            # The booster itself is written out as the native model below.
            continue
        if k == '_le':
            model[k] = save_label_encoder(v)
        elif k == 'classes_':
            model[k] = v.tolist()
        else:
            # Probe each attribute on its own so that a single value that
            # cannot be serialized only drops that attribute.
            try:
                json.dumps({k: v})
            except TypeError:
                warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
            else:
                model[k] = v

    booster = old.get_booster()
    # Store the JSON serialization as an attribute
    booster.set_attr(scikit_learn=json.dumps(model))

    # Save it into a native model.  Only the file name of the pickle goes
    # into the output name; prefixing the whole input path would point into
    # a directory that does not exist whenever the path has a directory part.
    directory, name = os.path.split(skl_model)
    i = 0
    while True:
        path = os.path.join(
            directory,
            'xgboost_native_model_from_' + name + '-' + str(i) + '.bin')
        if not os.path.exists(path):
            break
        i += 1
    booster.save_model(path)
|
|
|
|
|
|
if __name__ == '__main__':
    # Refuse to run under XGBoost 1.0.0: the script has to be executed with
    # the version that generated the pickle.  An explicit check is used
    # instead of ``assert`` so the guard is still enforced under ``-O``.
    if xgboost.__version__ == '1.0.0':
        raise RuntimeError('Please use the XGBoost version'
                           ' that generates this pickle.')

    parser = argparse.ArgumentParser(
        description=('A simple script to convert pickle generated by'
                     ' XGBoost 0.90 to XGBoost 1.0.0 model (not pickle).'))
    parser.add_argument(
        '--old-pickle',
        type=str,
        help='Path to old pickle file of Scikit-Learn interface object. '
        'Will output a native model converted from this pickle file',
        required=True)
    args = parser.parse_args()

    # Convert the pickle given on the command line into a native model.
    xgboost_skl_90to100(args.old_pickle)
|