'''This is a simple script that converts a pickled XGBoost
Scikit-Learn interface object from 0.90 to a native model.  Pickle
format is not stable as it's a direct serialization of Python object.
We advise not to use it when stability is needed.

Run it with the same XGBoost version that generated the pickle.  The
native model is written to the current working directory.

'''
import argparse
import json
import os
import pickle
import warnings

import numpy as np
import xgboost


def save_label_encoder(le):
    '''Save the label encoder in XGBClassifier.

    Returns a dict holding the instance attributes of `le`, with every
    numpy array converted to a plain list so the result can be fed to
    `json.dumps`.  Other values are passed through unchanged.

    '''
    return {
        k: v.tolist() if isinstance(v, np.ndarray) else v
        for k, v in le.__dict__.items()
    }


def xgboost_skl_90to100(skl_model):
    '''Extract the model and related metadata in SKL model.

    `skl_model` is the path to a pickled Scikit-Learn interface object
    (an `xgboost.XGBModel` instance).  The JSON serializable Python
    attributes of that object are stored on its booster under the
    `scikit_learn` attribute, and the booster is then saved into the
    current working directory as
    `xgboost_native_model_from_<file name>-<i>.bin`, where `<i>` is the
    first non-negative integer that does not collide with an existing
    file.

    Raises TypeError if the pickle does not contain an XGBModel.

    '''
    model = {}
    # NOTE: unpickling can execute arbitrary code; only run this script
    # on pickle files coming from a trusted source.
    with open(skl_model, 'rb') as fd:
        old = pickle.load(fd)
    if not isinstance(old, xgboost.XGBModel):
        raise TypeError(
            'The script only handes Scikit-Learn interface object')

    # Save Scikit-Learn specific Python attributes into a JSON document.
    for k, v in old.__dict__.items():
        if k == '_le':
            model[k] = save_label_encoder(v)
        elif k == 'classes_':
            model[k] = v.tolist()
        elif k == '_Booster':
            # The booster itself is saved as the native model below.
            continue
        else:
            # Keep only what JSON can represent; warn about the rest
            # instead of failing the whole conversion.
            try:
                json.dumps({k: v})
                model[k] = v
            except TypeError:
                warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
    booster = old.get_booster()
    # Store the JSON serialization as an attribute
    booster.set_attr(scikit_learn=json.dumps(model))

    # Save it into a native model.  Only the file name of the input is
    # used: embedding a path with directory components would point into
    # a directory that does not exist and make saving fail.
    prefix = 'xgboost_native_model_from_' + os.path.basename(skl_model)
    i = 0
    path = prefix + '-' + str(i) + '.bin'
    while os.path.exists(path):
        i += 1
        path = prefix + '-' + str(i) + '.bin'
    booster.save_model(path)


if __name__ == '__main__':
    # Explicit check instead of `assert`, which is removed under
    # `python -O` and would let the guard silently disappear.
    if xgboost.__version__ == '1.0.0':
        raise RuntimeError('Please use the XGBoost version'
                           ' that generates this pickle.')

    parser = argparse.ArgumentParser(
        description=('A simple script to convert pickle generated by'
                     ' XGBoost 0.90 to XGBoost 1.0.0 model (not pickle).'))
    parser.add_argument(
        '--old-pickle',
        type=str,
        help='Path to old pickle file of Scikit-Learn interface object. '
        'Will output a native model converted from this pickle file',
        required=True)
    args = parser.parse_args()

    xgboost_skl_90to100(args.old_pickle)