diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index f134c0399..3a5e6af80 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1052,7 +1052,7 @@ class Booster(object): _check_call( _LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length)) self.__dict__.update(state) - elif isinstance(model_file, (STRING_TYPES, os_PathLike)): + elif isinstance(model_file, (STRING_TYPES, os_PathLike, bytearray)): self.load_model(model_file) elif model_file is None: pass @@ -1512,7 +1512,8 @@ class Booster(object): return ctypes2buffer(cptr, length.value) def load_model(self, fname): - """Load the model from a file, local or as URI. + """Load the model from a file or bytearray. Path to file can be local + or as an URI. The model is loaded from an XGBoost format which is universal among the various XGBoost interfaces. Auxiliary attributes of the Python Booster @@ -1530,6 +1531,12 @@ class Booster(object): # from URL. _check_call(_LIB.XGBoosterLoadModel( self.handle, c_str(os_fspath(fname)))) + elif isinstance(fname, bytearray): + buf = fname + length = c_bst_ulong(len(buf)) + ptr = (ctypes.c_char * len(buf)).from_buffer(buf) + _check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, + length)) else: raise TypeError('Unknown file type: ', fname) diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 9d947f1f4..c94012b2a 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -300,6 +300,13 @@ class TestModels(unittest.TestCase): assert float(config['learner']['objective'][ 'reg_loss_param']['scale_pos_weight']) == 0.5 + buf = bst.save_raw() + from_raw = xgb.Booster() + from_raw.load_model(buf) + + buf_from_raw = from_raw.save_raw() + assert buf == buf_from_raw + def test_model_json_io(self): loc = locale.getpreferredencoding(False) model_path = 'test_model_json_io.json'