@@ -1052,7 +1052,7 @@ class Booster(object):
|
||||
_check_call(
|
||||
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
|
||||
self.__dict__.update(state)
|
||||
elif isinstance(model_file, (STRING_TYPES, os_PathLike)):
|
||||
elif isinstance(model_file, (STRING_TYPES, os_PathLike, bytearray)):
|
||||
self.load_model(model_file)
|
||||
elif model_file is None:
|
||||
pass
|
||||
@@ -1512,7 +1512,8 @@ class Booster(object):
|
||||
return ctypes2buffer(cptr, length.value)
|
||||
|
||||
def load_model(self, fname):
|
||||
"""Load the model from a file, local or as URI.
|
||||
"""Load the model from a file or bytearray. Path to file can be local
|
||||
or as an URI.
|
||||
|
||||
The model is loaded from an XGBoost format which is universal among the
|
||||
various XGBoost interfaces. Auxiliary attributes of the Python Booster
|
||||
@@ -1530,6 +1531,12 @@ class Booster(object):
|
||||
# from URL.
|
||||
_check_call(_LIB.XGBoosterLoadModel(
|
||||
self.handle, c_str(os_fspath(fname))))
|
||||
elif isinstance(fname, bytearray):
|
||||
buf = fname
|
||||
length = c_bst_ulong(len(buf))
|
||||
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr,
|
||||
length))
|
||||
else:
|
||||
raise TypeError('Unknown file type: ', fname)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user