[jvm-packages] Expose json formatted booster dumps (#2233) (#2234)

* Change Booster dump from XGBoosterDumpModel to XGBoosterDumpModelEx

Allows exposing multiple formatting options of model dumping.
This commit is contained in:
ebernhardson 2017-04-29 20:23:09 -07:00 committed by Nan Zhu
parent c441d0916e
commit d3b866e3fd
5 changed files with 31 additions and 11 deletions

View File

@ -333,6 +333,11 @@ public class Booster implements Serializable, KryoSerializable {
* @throws XGBoostError native error * @throws XGBoostError native error
*/ */
public String[] getModelDump(String featureMap, boolean withStats) throws XGBoostError { public String[] getModelDump(String featureMap, boolean withStats) throws XGBoostError {
return getModelDump(featureMap, withStats, "text");
}
public String[] getModelDump(String featureMap, boolean withStats, String format)
throws XGBoostError {
int statsFlag = 0; int statsFlag = 0;
if (featureMap == null) { if (featureMap == null) {
featureMap = ""; featureMap = "";
@ -340,9 +345,12 @@ public class Booster implements Serializable, KryoSerializable {
if (withStats) { if (withStats) {
statsFlag = 1; statsFlag = 1;
} }
if (format == null) {
format = "text";
}
String[][] modelInfos = new String[1][]; String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall( JNIErrorHandle.checkCall(
XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos)); XGBoostJNI.XGBoosterDumpModelEx(handle, featureMap, statsFlag, format, modelInfos));
return modelInfos[0]; return modelInfos[0];
} }
@ -389,7 +397,8 @@ public class Booster implements Serializable, KryoSerializable {
statsFlag = 1; statsFlag = 1;
} }
String[][] modelInfos = new String[1][]; String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos)); JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
modelInfos));
return modelInfos[0]; return modelInfos[0];
} }

View File

@ -84,8 +84,8 @@ class XGBoostJNI {
public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes); public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes);
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats,
String[][] out_strings); String format, String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string); public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value); public final static native int XGBoosterSetAttr(long handle, String key, String value);

View File

@ -623,17 +623,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterDumpModel * Method: XGBoosterDumpModelEx
* Signature: (JLjava/lang/String;I)[Ljava/lang/String; * Signature: (JLjava/lang/String;I)[Ljava/lang/String;
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelEx
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jstring jformat, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
const char *fmap = jenv->GetStringUTFChars(jfmap, 0); const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
const char *format = jenv->GetStringUTFChars(jformat, 0);
bst_ulong len = 0; bst_ulong len = 0;
char **result; char **result;
int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result); int ret = XGBoosterDumpModelEx(handle, fmap, jwith_stats, format, &len, (const char ***) &result);
jsize jlen = (jsize) len; jsize jlen = (jsize) len;
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));

View File

@ -217,11 +217,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterDumpModel * Method: XGBoosterDumpModelEx
* Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I * Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelEx
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray); (JNIEnv *, jclass, jlong, jstring, jint, jstring, jobjectArray);
/* /*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Class: ml_dmlc_xgboost4j_java_XGBoostJNI

View File

@ -231,6 +231,16 @@ public class BoosterImplTest {
testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f); testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f);
} }
@Test
public void testDumpModelJson() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
Booster booster = trainBooster(trainMat, testMat);
String[] dump = booster.getModelDump("", false, "json");
TestCase.assertEquals(" { \"nodeid\":", dump[0].substring(0, 13));
}
@Test @Test
public void testFastHistoDepthwiseMaxDepth() throws XGBoostError { public void testFastHistoDepthwiseMaxDepth() throws XGBoostError {
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");