parent
7d178cbd25
commit
2bc5d8d449
@ -1052,7 +1052,7 @@ class Booster(object):
|
|||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
|
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
|
||||||
self.__dict__.update(state)
|
self.__dict__.update(state)
|
||||||
elif isinstance(model_file, (STRING_TYPES, os_PathLike)):
|
elif isinstance(model_file, (STRING_TYPES, os_PathLike, bytearray)):
|
||||||
self.load_model(model_file)
|
self.load_model(model_file)
|
||||||
elif model_file is None:
|
elif model_file is None:
|
||||||
pass
|
pass
|
||||||
@ -1512,7 +1512,8 @@ class Booster(object):
|
|||||||
return ctypes2buffer(cptr, length.value)
|
return ctypes2buffer(cptr, length.value)
|
||||||
|
|
||||||
def load_model(self, fname):
|
def load_model(self, fname):
|
||||||
"""Load the model from a file, local or as URI.
|
"""Load the model from a file or bytearray. Path to file can be local
|
||||||
|
or as an URI.
|
||||||
|
|
||||||
The model is loaded from an XGBoost format which is universal among the
|
The model is loaded from an XGBoost format which is universal among the
|
||||||
various XGBoost interfaces. Auxiliary attributes of the Python Booster
|
various XGBoost interfaces. Auxiliary attributes of the Python Booster
|
||||||
@ -1530,6 +1531,12 @@ class Booster(object):
|
|||||||
# from URL.
|
# from URL.
|
||||||
_check_call(_LIB.XGBoosterLoadModel(
|
_check_call(_LIB.XGBoosterLoadModel(
|
||||||
self.handle, c_str(os_fspath(fname))))
|
self.handle, c_str(os_fspath(fname))))
|
||||||
|
elif isinstance(fname, bytearray):
|
||||||
|
buf = fname
|
||||||
|
length = c_bst_ulong(len(buf))
|
||||||
|
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||||
|
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr,
|
||||||
|
length))
|
||||||
else:
|
else:
|
||||||
raise TypeError('Unknown file type: ', fname)
|
raise TypeError('Unknown file type: ', fname)
|
||||||
|
|
||||||
|
|||||||
@ -300,6 +300,13 @@ class TestModels(unittest.TestCase):
|
|||||||
assert float(config['learner']['objective'][
|
assert float(config['learner']['objective'][
|
||||||
'reg_loss_param']['scale_pos_weight']) == 0.5
|
'reg_loss_param']['scale_pos_weight']) == 0.5
|
||||||
|
|
||||||
|
buf = bst.save_raw()
|
||||||
|
from_raw = xgb.Booster()
|
||||||
|
from_raw.load_model(buf)
|
||||||
|
|
||||||
|
buf_from_raw = from_raw.save_raw()
|
||||||
|
assert buf == buf_from_raw
|
||||||
|
|
||||||
def test_model_json_io(self):
|
def test_model_json_io(self):
|
||||||
loc = locale.getpreferredencoding(False)
|
loc = locale.getpreferredencoding(False)
|
||||||
model_path = 'test_model_json_io.json'
|
model_path = 'test_model_json_io.json'
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user