diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 5f1cb8465..86b70a01a 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1280,36 +1280,54 @@ class Booster(object): ptr = (ctypes.c_char * len(buf)).from_buffer(buf) _check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length)) - def dump_model(self, fout, fmap='', with_stats=False): + def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"): """ - Dump model into a text file. + Dump model into a text or JSON file. Parameters ---------- - foout : string + fout : string Output file name. fmap : string, optional Name of the file containing feature map names. - with_stats : bool (optional) + with_stats : bool, optional Controls whether the split statistics are output. + dump_format : string, optional + Format of model dump file. Can be 'text' or 'json'. """ if isinstance(fout, STRING_TYPES): fout = open(fout, 'w') need_close = True else: need_close = False - ret = self.get_dump(fmap, with_stats) - for i in range(len(ret)): - fout.write('booster[{}]:\n'.format(i)) - fout.write(ret[i]) + ret = self.get_dump(fmap, with_stats, dump_format) + if dump_format == 'json': + fout.write('[\n') + for i in range(len(ret)): + fout.write(ret[i]) + if i < len(ret) - 1: + fout.write(",\n") + fout.write('\n]') + else: + for i in range(len(ret)): + fout.write('booster[{}]:\n'.format(i)) + fout.write(ret[i]) if need_close: fout.close() def get_dump(self, fmap='', with_stats=False, dump_format="text"): """ - Returns the dump the model as a list of strings. - """ + Returns the model dump as a list of strings. + Parameters + ---------- + fmap : string, optional + Name of the file containing feature map names. + with_stats : bool, optional + Controls whether the split statistics are output. + dump_format : string, optional + Format of model dump. Can be 'text' or 'json'. + """ length = c_bst_ulong() sarr = ctypes.POINTER(ctypes.c_char_p)() if self.feature_names is not None and fmap == '':