JSON configuration IO. (#5111)

* Add saving/loading JSON configuration.
* Implement Python pickle interface with new IO routines.
* Basic tests for training continuation.
This commit is contained in:
Jiaming Yuan
2019-12-15 17:31:53 +08:00
committed by GitHub
parent 5aa007d7b2
commit 3136185bc5
24 changed files with 761 additions and 390 deletions

View File

@@ -1076,28 +1076,47 @@ class Booster(object):
self.handle = ctypes.c_void_p()
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(len(cache)),
ctypes.byref(self.handle)))
self.set_param({'seed': 0})
self.set_param(params or {})
if (params is not None) and ('booster' in params):
self.booster = params['booster']
else:
self.booster = 'gbtree'
if model_file is not None:
if isinstance(model_file, Booster):
assert self.handle is not None
# We use the pickle interface for getting memory snapshot from
# another model, and load the snapshot with this booster.
state = model_file.__getstate__()
handle = state['handle']
del state['handle']
ptr = (ctypes.c_char * len(handle)).from_buffer(handle)
length = c_bst_ulong(len(handle))
_check_call(
_LIB.XGBoosterUnserializeFromBuffer(self.handle, ptr, length))
self.__dict__.update(state)
elif isinstance(model_file, (STRING_TYPES, os_PathLike)):
self.load_model(model_file)
elif model_file is None:
pass
else:
raise TypeError('Unknown type:', model_file)
def __del__(self):
    """Release the native booster handle when the object is destroyed.

    The finalizer can run on an instance whose ``handle`` attribute was
    never assigned (for example when construction failed early), so the
    attribute is looked up defensively instead of being assumed to exist.
    """
    if getattr(self, 'handle', None) is not None:
        _check_call(_LIB.XGBoosterFree(self.handle))
        # Forget the handle so it cannot be freed a second time.
        self.handle = None
def __getstate__(self):
    """Return a picklable snapshot of this booster's state.

    ctypes pointers cannot be pickled, so the ``handle`` entry of the copied
    instance dictionary is replaced with the buffer obtained from
    ``XGBoosterSerializeToBuffer``.  When the handle is ``None`` the copied
    dictionary is returned unchanged.

    Returns
    -------
    dict
        Shallow copy of ``self.__dict__`` with ``handle`` swapped for the
        serialized content.
    """
    this = self.__dict__.copy()
    if this['handle'] is not None:
        length = c_bst_ulong()
        cptr = ctypes.POINTER(ctypes.c_char)()
        _check_call(_LIB.XGBoosterSerializeToBuffer(self.handle,
                                                    ctypes.byref(length),
                                                    ctypes.byref(cptr)))
        # Store the serialized content in place of the raw pointer.
        this["handle"] = ctypes2buffer(cptr, length.value)
    return this
def __setstate__(self, state):
@@ -1107,18 +1126,44 @@ class Booster(object):
buf = handle
dmats = c_array(ctypes.c_void_p, [])
handle = ctypes.c_void_p()
_check_call(_LIB.XGBoosterCreate(dmats, c_bst_ulong(0), ctypes.byref(handle)))
_check_call(_LIB.XGBoosterCreate(
dmats, c_bst_ulong(0), ctypes.byref(handle)))
length = c_bst_ulong(len(buf))
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
_check_call(_LIB.XGBoosterLoadModelFromBuffer(handle, ptr, length))
_check_call(
_LIB.XGBoosterUnserializeFromBuffer(handle, ptr, length))
state['handle'] = handle
self.__dict__.update(state)
def save_config(self):
    '''Return the internal parameter configuration of the Booster as a
    JSON string.'''
    config_len = c_bst_ulong()
    config_ptr = ctypes.c_char_p()
    _check_call(_LIB.XGBoosterSaveJsonConfig(
        self.handle,
        ctypes.byref(config_len),
        ctypes.byref(config_ptr)))
    # The native call fills a C string; hand text back to the caller.
    return config_ptr.value.decode()
def load_config(self, config):
    '''Load configuration returned by `save_config`.

    Parameters
    ----------
    config : str
        JSON configuration string, as produced by `save_config`.

    Raises
    ------
    TypeError
        If `config` is not a `str`.
    '''
    # Validate explicitly: an `assert` is stripped under `python -O`, which
    # would let a non-string value reach the native call.
    if not isinstance(config, str):
        raise TypeError(
            'config must be a str, got: {}'.format(type(config).__name__))
    _check_call(_LIB.XGBoosterLoadJsonConfig(
        self.handle,
        c_str(config)))
def __copy__(self):
    '''Shallow-copy hook; delegates to `__deepcopy__` with no memo.'''
    duplicate = self.__deepcopy__(None)
    return duplicate
def __deepcopy__(self, _):
    '''Return a copy of booster. Caches for DMatrix are not copied so continue
    training on copied booster will result in lower performance and
    slightly different result.

    The copy is made by handing this booster itself to the constructor
    rather than a `save_raw` buffer.
    '''
    return Booster(model_file=self)
def copy(self):
"""Copy the booster object.
@@ -1451,20 +1496,22 @@ class Booster(object):
def save_model(self, fname):
    """Save the model to a file.

    The model is saved in an XGBoost internal format which is universal
    among the various XGBoost interfaces. Auxiliary attributes of the
    Python Booster object (such as feature_names) will not be saved. To
    preserve all attributes, pickle the Booster object.

    Parameters
    ----------
    fname : string or os.PathLike
        Output file name

    Raises
    ------
    TypeError
        If `fname` is neither a string nor an os.PathLike object.
    """
    # Reject unsupported types up front so the native call happens once,
    # and only for a usable file name.
    if not isinstance(fname, (STRING_TYPES, os_PathLike)):
        raise TypeError("fname must be a string or os_PathLike")
    _check_call(_LIB.XGBoosterSaveModel(
        self.handle, c_str(os_fspath(fname))))
def save_raw(self):
"""Save the model to a in memory buffer representation
@@ -1481,26 +1528,26 @@ class Booster(object):
return ctypes2buffer(cptr, length.value)
def load_model(self, fname):
    """Load the model from a file, local or as URI.

    The model is loaded from an XGBoost format which is universal among the
    various XGBoost interfaces. Auxiliary attributes of the Python Booster
    object (such as feature_names) will not be loaded. To preserve all
    attributes, pickle the Booster object.

    Parameters
    ----------
    fname : string, os.PathLike, or a memory buffer
        Input file name or memory buffer(see also save_raw).  Only
        `bytearray` buffers are accepted.
        NOTE(review): assumes `save_raw` returns a bytearray -- confirm.

    Raises
    ------
    TypeError
        If `fname` is none of the supported types.
    """
    if isinstance(fname, (STRING_TYPES, os_PathLike)):
        # assume file name, cannot use os.path.exist to check, file can be
        # from URL.
        _check_call(_LIB.XGBoosterLoadModel(
            self.handle, c_str(os_fspath(fname))))
    elif isinstance(fname, bytearray):
        # In-memory model: expose the buffer to the native loader as a
        # ctypes char array of the same length.
        buf = fname
        length = c_bst_ulong(len(buf))
        ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
        _check_call(
            _LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))
    else:
        raise TypeError('Unknown file type: ', fname)
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
"""Dump model into a text or JSON file.

View File

@@ -34,9 +34,8 @@ def _train_internal(params, dtrain,
num_parallel_tree = 1
if xgb_model is not None:
if not isinstance(xgb_model, STRING_TYPES):
xgb_model = xgb_model.save_raw()
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
bst = Booster(params, [dtrain] + [d[0] for d in evals],
model_file=xgb_model)
nboost = len(bst.get_dump())
_params = dict(params) if isinstance(params, list) else params