JSON configuration IO. (#5111)
* Add saving/loading JSON configuration. * Implement Python pickle interface with new IO routines. * Basic tests for training continuation.
This commit is contained in:
@@ -1076,28 +1076,47 @@ class Booster(object):
|
||||
self.handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
|
||||
ctypes.byref(self.handle)))
|
||||
self.set_param({'seed': 0})
|
||||
self.set_param(params or {})
|
||||
if (params is not None) and ('booster' in params):
|
||||
self.booster = params['booster']
|
||||
else:
|
||||
self.booster = 'gbtree'
|
||||
if model_file is not None:
|
||||
if isinstance(model_file, Booster):
|
||||
assert self.handle is not None
|
||||
# We use the pickle interface for getting memory snapshot from
|
||||
# another model, and load the snapshot with this booster.
|
||||
state = model_file.__getstate__()
|
||||
handle = state['handle']
|
||||
del state['handle']
|
||||
ptr = (ctypes.c_char * len(handle)).from_buffer(handle)
|
||||
length = c_bst_ulong(len(handle))
|
||||
_check_call(
|
||||
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
|
||||
self.__dict__.update(state)
|
||||
elif isinstance(model_file, (STRING_TYPES, os_PathLike)):
|
||||
self.load_model(model_file)
|
||||
elif model_file is None:
|
||||
pass
|
||||
else:
|
||||
raise TypeError('Unknown type:', model_file)
|
||||
|
||||
def __del__(self):
    """Release the native booster handle when the object is collected."""
    # ``handle`` may not exist if ``__init__`` raised before assigning it, so
    # look it up defensively rather than touching the attribute directly.
    if getattr(self, 'handle', None) is not None:
        _check_call(_LIB.XGBoosterFree(self.handle))
        # Reset so a second finalization does not free the handle twice.
        self.handle = None
|
||||
|
||||
def __getstate__(self):
    """Return a picklable copy of the instance state.

    ctypes pointers can't be pickled, so when a native handle is present it
    is replaced by the buffer that ``ctypes2buffer`` builds from the output
    of ``XGBoosterSerializeToBuffer``.

    Returns
    -------
    this : dict
        Copy of ``self.__dict__`` whose ``'handle'`` entry is either ``None``
        or the serialized booster buffer.
    """
    this = self.__dict__.copy()
    if this['handle'] is not None:
        length = c_bst_ulong()
        cptr = ctypes.POINTER(ctypes.c_char)()
        _check_call(_LIB.XGBoosterSerializeToBuffer(self.handle,
                                                    ctypes.byref(length),
                                                    ctypes.byref(cptr)))
        this["handle"] = ctypes2buffer(cptr, length.value)
    return this
|
||||
|
||||
def __setstate__(self, state):
|
||||
@@ -1107,18 +1126,44 @@ class Booster(object):
|
||||
buf = handle
|
||||
dmats = c_array(ctypes.c_void_p, [])
|
||||
handle = ctypes.c_void_p()
|
||||
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(0), ctypes.byref(handle)))
|
||||
_check_call(_LIB.XGBoosterCreate(
|
||||
dmats, c_bst_ulong(0), ctypes.byref(handle)))
|
||||
length = c_bst_ulong(len(buf))
|
||||
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||
_check_call(_LIB.XGBoosterLoadModelFromBuffer(handle, ptr, length))
|
||||
_check_call(
|
||||
_LIB.XGBoosterUnserializeFromBuffer(handle, ptr, length))
|
||||
state['handle'] = handle
|
||||
self.__dict__.update(state)
|
||||
|
||||
def save_config(self):
    """Output internal parameter configuration of Booster as a JSON string."""
    size = c_bst_ulong()
    raw_config = ctypes.c_char_p()
    status = _LIB.XGBoosterSaveJsonConfig(
        self.handle, ctypes.byref(size), ctypes.byref(raw_config))
    _check_call(status)
    # The native call fills ``raw_config`` with bytes; hand back text.
    return raw_config.value.decode()
|
||||
|
||||
def load_config(self, config):
    """Load configuration returned by `save_config`.

    Parameters
    ----------
    config : str
        Configuration string as produced by `save_config`.

    Raises
    ------
    TypeError
        If `config` is not a `str`.
    """
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``
    # and let a non-string reach the native call.
    if not isinstance(config, str):
        raise TypeError(
            'config must be a str, got: ' + type(config).__name__)
    _check_call(_LIB.XGBoosterLoadJsonConfig(self.handle, c_str(config)))
|
||||
|
||||
def __copy__(self):
    """Support ``copy.copy`` by delegating to ``__deepcopy__`` with no memo."""
    memo = None
    return self.__deepcopy__(memo)
|
||||
|
||||
def __deepcopy__(self, _):
    """Return a copy of booster.

    Caches for DMatrix are not copied, so continuing training on the copied
    booster will result in lower performance and slightly different result.

    Parameters
    ----------
    _ : object
        Memo argument supplied by the ``copy`` protocol; unused.
    """
    # Pass the booster itself rather than ``save_raw()`` output: the
    # constructor takes its snapshot-copy path only for Booster instances.
    return Booster(model_file=self)
|
||||
|
||||
def copy(self):
|
||||
"""Copy the booster object.
|
||||
@@ -1451,20 +1496,22 @@ class Booster(object):
|
||||
def save_model(self, fname):
    """Save the model to a file.

    The model is saved in an XGBoost internal format which is universal
    among the various XGBoost interfaces. Auxiliary attributes of the
    Python Booster object (such as feature_names) will not be saved. To
    preserve all attributes, pickle the Booster object.

    Parameters
    ----------
    fname : string or os.PathLike
        Output file name

    Raises
    ------
    TypeError
        If `fname` is neither a string nor an os.PathLike.
    """
    if not isinstance(fname, (STRING_TYPES, os_PathLike)):
        raise TypeError("fname must be a string or os.PathLike")
    # Write exactly once; the path is normalised through os_fspath first.
    _check_call(_LIB.XGBoosterSaveModel(
        self.handle, c_str(os_fspath(fname))))
|
||||
|
||||
def save_raw(self):
|
||||
"""Save the model to a in memory buffer representation
|
||||
@@ -1481,26 +1528,26 @@ class Booster(object):
|
||||
return ctypes2buffer(cptr, length.value)
|
||||
|
||||
def load_model(self, fname):
    """Load the model from a file, local or as URI.

    The model is loaded from an XGBoost format which is universal among the
    various XGBoost interfaces. Auxiliary attributes of the Python Booster
    object (such as feature_names) will not be loaded. To preserve all
    attributes, pickle the Booster object.

    Parameters
    ----------
    fname : string, os.PathLike, or bytearray
        Input file name or memory buffer (see also save_raw)

    Raises
    ------
    TypeError
        If `fname` is not a string, an os.PathLike or a bytearray.
    """
    if isinstance(fname, (STRING_TYPES, os_PathLike)):
        # assume file name, cannot use os.path.exist to check, file can be
        # from URL.
        _check_call(_LIB.XGBoosterLoadModel(
            self.handle, c_str(os_fspath(fname))))
    elif isinstance(fname, bytearray):
        # In-memory model: expose the bytearray to the native library as a
        # char array without copying it.
        buf = fname
        length = c_bst_ulong(len(buf))
        ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
        _check_call(
            _LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))
    else:
        raise TypeError('Unknown file type: ', fname)
|
||||
|
||||
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
|
||||
"""Dump model into a text or JSON file.
|
||||
|
||||
@@ -34,9 +34,8 @@ def _train_internal(params, dtrain,
|
||||
num_parallel_tree = 1
|
||||
|
||||
if xgb_model is not None:
|
||||
if not isinstance(xgb_model, STRING_TYPES):
|
||||
xgb_model = xgb_model.save_raw()
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals],
|
||||
model_file=xgb_model)
|
||||
nboost = len(bst.get_dump())
|
||||
|
||||
_params = dict(params) if isinstance(params, list) else params
|
||||
|
||||
Reference in New Issue
Block a user