Add JSON model dump functionality (#3603)

* Add JSON model dump functionality

* Fix lint
This commit is contained in:
Grace Lam 2018-08-17 16:18:43 -07:00 committed by Philip Hyunsu Cho
parent b53a5a262c
commit 993e62b9e7

View File

@ -1280,25 +1280,35 @@ class Booster(object):
ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))
def dump_model(self, fout, fmap='', with_stats=False):
def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
"""
Dump model into a text file.
Dump model into a text or JSON file.
Parameters
----------
foout : string
fout : string
Output file name.
fmap : string, optional
Name of the file containing feature map names.
with_stats : bool (optional)
with_stats : bool, optional
Controls whether the split statistics are output.
dump_format : string, optional
Format of model dump file. Can be 'text' or 'json'.
"""
if isinstance(fout, STRING_TYPES):
fout = open(fout, 'w')
need_close = True
else:
need_close = False
ret = self.get_dump(fmap, with_stats)
ret = self.get_dump(fmap, with_stats, dump_format)
if dump_format == 'json':
fout.write('[\n')
for i in range(len(ret)):
fout.write(ret[i])
if i < len(ret) - 1:
fout.write(",\n")
fout.write('\n]')
else:
for i in range(len(ret)):
fout.write('booster[{}]:\n'.format(i))
fout.write(ret[i])
@ -1307,9 +1317,17 @@ class Booster(object):
def get_dump(self, fmap='', with_stats=False, dump_format="text"):
"""
Returns the dump the model as a list of strings.
"""
Returns the model dump as a list of strings.
Parameters
----------
fmap : string, optional
Name of the file containing feature map names.
with_stats : bool, optional
Controls whether the split statistics are output.
dump_format : string, optional
Format of model dump. Can be 'text' or 'json'.
"""
length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)()
if self.feature_names is not None and fmap == '':