diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 7c01de4a0..672d538ea 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -333,6 +333,11 @@ public class Booster implements Serializable, KryoSerializable { * @throws XGBoostError native error */ public String[] getModelDump(String featureMap, boolean withStats) throws XGBoostError { + return getModelDump(featureMap, withStats, "text"); + } + + public String[] getModelDump(String featureMap, boolean withStats, String format) + throws XGBoostError { int statsFlag = 0; if (featureMap == null) { featureMap = ""; @@ -340,9 +345,12 @@ public class Booster implements Serializable, KryoSerializable { if (withStats) { statsFlag = 1; } + if (format == null) { + format = "text"; + } String[][] modelInfos = new String[1][]; JNIErrorHandle.checkCall( - XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos)); + XGBoostJNI.XGBoosterDumpModelEx(handle, featureMap, statsFlag, format, modelInfos)); return modelInfos[0]; } @@ -389,7 +397,8 @@ public class Booster implements Serializable, KryoSerializable { statsFlag = 1; } String[][] modelInfos = new String[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos)); + JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text", + modelInfos)); return modelInfos[0]; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 630c61647..4d0f31dd1 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -84,8 +84,8 @@ class XGBoostJNI { public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes); - public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, - String[][] out_strings); + public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats, + String format, String[][] out_strings); public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string); public final static native int XGBoosterSetAttr(long handle, String key, String value); diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 0c4a85dcc..6b027e2f5 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -623,17 +623,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterDumpModel + * Method: XGBoosterDumpModelEx * Signature: (JLjava/lang/String;I)[Ljava/lang/String; */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelEx + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jstring jformat, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; const char *fmap = jenv->GetStringUTFChars(jfmap, 0); + const char *format = jenv->GetStringUTFChars(jformat, 0); bst_ulong len = 0; char **result; - int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result); + int ret = XGBoosterDumpModelEx(handle, fmap, jwith_stats, format, &len, (const char ***) &result); jsize jlen = (jsize) len; jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 8e42eea1c..231f727ce 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -217,11 +217,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI - * Method: XGBoosterDumpModel + * Method: XGBoosterDumpModelEx * Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel - (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelEx + (JNIEnv *, jclass, jlong, jstring, jint, jstring, jobjectArray); /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index c60b2406a..f16f3ef63 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -231,6 +231,16 @@ public class BoosterImplTest { testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f); } + @Test + public void testDumpModelJson() throws XGBoostError { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + Booster booster = trainBooster(trainMat, testMat); + String[] dump = booster.getModelDump("", false, "json"); + TestCase.assertEquals(" { \"nodeid\":", dump[0].substring(0, 13)); + } + @Test public void testFastHistoDepthwiseMaxDepth() throws XGBoostError { DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");