Restore loading model from buffer. (#5360) (#5366)

This commit is contained in:
Jiaming Yuan 2020-02-26 14:23:10 +08:00 committed by GitHub
parent 7d178cbd25
commit 2bc5d8d449
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions

View File

@ -1052,7 +1052,7 @@ class Booster(object):
_check_call( _check_call(
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length)) _LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
self.__dict__.update(state) self.__dict__.update(state)
elif isinstance(model_file, (STRING_TYPES, os_PathLike)): elif isinstance(model_file, (STRING_TYPES, os_PathLike, bytearray)):
self.load_model(model_file) self.load_model(model_file)
elif model_file is None: elif model_file is None:
pass pass
@ -1512,7 +1512,8 @@ class Booster(object):
return ctypes2buffer(cptr, length.value) return ctypes2buffer(cptr, length.value)
def load_model(self, fname): def load_model(self, fname):
"""Load the model from a file, local or as URI. """Load the model from a file or bytearray. Path to file can be local
or as an URI.
The model is loaded from an XGBoost format which is universal among the The model is loaded from an XGBoost format which is universal among the
various XGBoost interfaces. Auxiliary attributes of the Python Booster various XGBoost interfaces. Auxiliary attributes of the Python Booster
@ -1530,6 +1531,12 @@ class Booster(object):
# from URL. # from URL.
_check_call(_LIB.XGBoosterLoadModel( _check_call(_LIB.XGBoosterLoadModel(
self.handle, c_str(os_fspath(fname)))) self.handle, c_str(os_fspath(fname))))
elif isinstance(fname, bytearray):
buf = fname
length = c_bst_ulong(len(buf))
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr,
length))
else: else:
raise TypeError('Unknown file type: ', fname) raise TypeError('Unknown file type: ', fname)

View File

@ -300,6 +300,13 @@ class TestModels(unittest.TestCase):
assert float(config['learner']['objective'][ assert float(config['learner']['objective'][
'reg_loss_param']['scale_pos_weight']) == 0.5 'reg_loss_param']['scale_pos_weight']) == 0.5
buf = bst.save_raw()
from_raw = xgb.Booster()
from_raw.load_model(buf)
buf_from_raw = from_raw.save_raw()
assert buf == buf_from_raw
def test_model_json_io(self): def test_model_json_io(self):
loc = locale.getpreferredencoding(False) loc = locale.getpreferredencoding(False)
model_path = 'test_model_json_io.json' model_path = 'test_model_json_io.json'