Add JSON model dump functionality (#3603)

* Add JSON model dump functionality

* Fix lint
This commit is contained in:
Grace Lam 2018-08-17 16:18:43 -07:00 committed by Philip Hyunsu Cho
parent b53a5a262c
commit 993e62b9e7

View File

@ -1280,25 +1280,35 @@ class Booster(object):
ptr = (ctypes.c_char * len(buf)).from_buffer(buf) ptr = (ctypes.c_char * len(buf)).from_buffer(buf)
_check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)) _check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))
def dump_model(self, fout, fmap='', with_stats=False): def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
""" """
Dump model into a text file. Dump model into a text or JSON file.
Parameters Parameters
---------- ----------
foout : string fout : string
Output file name. Output file name.
fmap : string, optional fmap : string, optional
Name of the file containing feature map names. Name of the file containing feature map names.
with_stats : bool (optional) with_stats : bool, optional
Controls whether the split statistics are output. Controls whether the split statistics are output.
dump_format : string, optional
Format of model dump file. Can be 'text' or 'json'.
""" """
if isinstance(fout, STRING_TYPES): if isinstance(fout, STRING_TYPES):
fout = open(fout, 'w') fout = open(fout, 'w')
need_close = True need_close = True
else: else:
need_close = False need_close = False
ret = self.get_dump(fmap, with_stats) ret = self.get_dump(fmap, with_stats, dump_format)
if dump_format == 'json':
fout.write('[\n')
for i in range(len(ret)):
fout.write(ret[i])
if i < len(ret) - 1:
fout.write(",\n")
fout.write('\n]')
else:
for i in range(len(ret)): for i in range(len(ret)):
fout.write('booster[{}]:\n'.format(i)) fout.write('booster[{}]:\n'.format(i))
fout.write(ret[i]) fout.write(ret[i])
@ -1307,9 +1317,17 @@ class Booster(object):
def get_dump(self, fmap='', with_stats=False, dump_format="text"): def get_dump(self, fmap='', with_stats=False, dump_format="text"):
""" """
Returns the dump the model as a list of strings. Returns the model dump as a list of strings.
"""
Parameters
----------
fmap : string, optional
Name of the file containing feature map names.
with_stats : bool, optional
Controls whether the split statistics are output.
dump_format : string, optional
Format of model dump. Can be 'text' or 'json'.
"""
length = c_bst_ulong() length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)() sarr = ctypes.POINTER(ctypes.c_char_p)()
if self.feature_names is not None and fmap == '': if self.feature_names is not None and fmap == '':