From 11fa4197208bf3c3054c93a9609602eb7b8b9c4d Mon Sep 17 00:00:00 2001 From: Skipper Seabold Date: Wed, 6 May 2015 12:33:43 -0500 Subject: [PATCH] ENH: Make XGBModel pickleable. --- wrapper/xgboost.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/wrapper/xgboost.py b/wrapper/xgboost.py index 8ef82b2c7..4a5248818 100644 --- a/wrapper/xgboost.py +++ b/wrapper/xgboost.py @@ -15,6 +15,7 @@ import re import ctypes import platform import collections +from tempfile import NamedTemporaryFile import numpy as np import scipy.sparse @@ -839,6 +840,31 @@ class XGBModel(XGBModelBase): self._Booster = Booster() + def __getstate__(self): + # can't pickle ctypes pointers so save _Booster directly + this = self.__dict__.copy() # don't modify in place + + # delete = False for x-platform compatibility + # https://bugs.python.org/issue14243 + with NamedTemporaryFile(mode="wb", delete=False) as tmp: + this["_Booster"].save_model(tmp.name) + tmp.close() + booster = open(tmp.name, "rb").read() + os.remove(tmp.name) + this.update({"_Booster": booster}) + + return this + + def __setstate__(self, state): + with NamedTemporaryFile(mode="wb", delete=False) as tmp: + tmp.write(state["_Booster"]) + tmp.close() + booster = Booster(model_file=tmp.name) + os.remove(tmp.name) + + state["_Booster"] = booster + self.__dict__.update(state) + def get_xgb_params(self): xgb_params = self.get_params()