Add JSON model dump functionality (#3603)
* Add JSON model dump functionality * Fix lint
This commit is contained in:
parent
b53a5a262c
commit
993e62b9e7
@ -1280,36 +1280,54 @@ class Booster(object):
|
|||||||
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
|
||||||
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))
|
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))
|
||||||
|
|
||||||
def dump_model(self, fout, fmap='', with_stats=False):
|
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
|
||||||
"""
|
"""
|
||||||
Dump model into a text file.
|
Dump model into a text or JSON file.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
foout : string
|
fout : string
|
||||||
Output file name.
|
Output file name.
|
||||||
fmap : string, optional
|
fmap : string, optional
|
||||||
Name of the file containing feature map names.
|
Name of the file containing feature map names.
|
||||||
with_stats : bool (optional)
|
with_stats : bool, optional
|
||||||
Controls whether the split statistics are output.
|
Controls whether the split statistics are output.
|
||||||
|
dump_format : string, optional
|
||||||
|
Format of model dump file. Can be 'text' or 'json'.
|
||||||
"""
|
"""
|
||||||
if isinstance(fout, STRING_TYPES):
|
if isinstance(fout, STRING_TYPES):
|
||||||
fout = open(fout, 'w')
|
fout = open(fout, 'w')
|
||||||
need_close = True
|
need_close = True
|
||||||
else:
|
else:
|
||||||
need_close = False
|
need_close = False
|
||||||
ret = self.get_dump(fmap, with_stats)
|
ret = self.get_dump(fmap, with_stats, dump_format)
|
||||||
for i in range(len(ret)):
|
if dump_format == 'json':
|
||||||
fout.write('booster[{}]:\n'.format(i))
|
fout.write('[\n')
|
||||||
fout.write(ret[i])
|
for i in range(len(ret)):
|
||||||
|
fout.write(ret[i])
|
||||||
|
if i < len(ret) - 1:
|
||||||
|
fout.write(",\n")
|
||||||
|
fout.write('\n]')
|
||||||
|
else:
|
||||||
|
for i in range(len(ret)):
|
||||||
|
fout.write('booster[{}]:\n'.format(i))
|
||||||
|
fout.write(ret[i])
|
||||||
if need_close:
|
if need_close:
|
||||||
fout.close()
|
fout.close()
|
||||||
|
|
||||||
def get_dump(self, fmap='', with_stats=False, dump_format="text"):
|
def get_dump(self, fmap='', with_stats=False, dump_format="text"):
|
||||||
"""
|
"""
|
||||||
Returns the dump the model as a list of strings.
|
Returns the model dump as a list of strings.
|
||||||
"""
|
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fmap : string, optional
|
||||||
|
Name of the file containing feature map names.
|
||||||
|
with_stats : bool, optional
|
||||||
|
Controls whether the split statistics are output.
|
||||||
|
dump_format : string, optional
|
||||||
|
Format of model dump. Can be 'text' or 'json'.
|
||||||
|
"""
|
||||||
length = c_bst_ulong()
|
length = c_bst_ulong()
|
||||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||||
if self.feature_names is not None and fmap == '':
|
if self.feature_names is not None and fmap == '':
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user