* Change Booster dump from XGBoosterDumpModel to XGBoosterDumpModelEx Allows exposing multiple formatting options of model dumping.
This commit is contained in:
parent
c441d0916e
commit
d3b866e3fd
@ -333,6 +333,11 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
* @throws XGBoostError native error
|
* @throws XGBoostError native error
|
||||||
*/
|
*/
|
||||||
public String[] getModelDump(String featureMap, boolean withStats) throws XGBoostError {
|
public String[] getModelDump(String featureMap, boolean withStats) throws XGBoostError {
|
||||||
|
return getModelDump(featureMap, withStats, "text");
|
||||||
|
}
|
||||||
|
|
||||||
|
public String[] getModelDump(String featureMap, boolean withStats, String format)
|
||||||
|
throws XGBoostError {
|
||||||
int statsFlag = 0;
|
int statsFlag = 0;
|
||||||
if (featureMap == null) {
|
if (featureMap == null) {
|
||||||
featureMap = "";
|
featureMap = "";
|
||||||
@ -340,9 +345,12 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
if (withStats) {
|
if (withStats) {
|
||||||
statsFlag = 1;
|
statsFlag = 1;
|
||||||
}
|
}
|
||||||
|
if (format == null) {
|
||||||
|
format = "text";
|
||||||
|
}
|
||||||
String[][] modelInfos = new String[1][];
|
String[][] modelInfos = new String[1][];
|
||||||
JNIErrorHandle.checkCall(
|
JNIErrorHandle.checkCall(
|
||||||
XGBoostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
|
XGBoostJNI.XGBoosterDumpModelEx(handle, featureMap, statsFlag, format, modelInfos));
|
||||||
return modelInfos[0];
|
return modelInfos[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -389,7 +397,8 @@ public class Booster implements Serializable, KryoSerializable {
|
|||||||
statsFlag = 1;
|
statsFlag = 1;
|
||||||
}
|
}
|
||||||
String[][] modelInfos = new String[1][];
|
String[][] modelInfos = new String[1][];
|
||||||
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
|
||||||
|
modelInfos));
|
||||||
return modelInfos[0];
|
return modelInfos[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -84,8 +84,8 @@ class XGBoostJNI {
|
|||||||
|
|
||||||
public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes);
|
public final static native int XGBoosterGetModelRaw(long handle, byte[][] out_bytes);
|
||||||
|
|
||||||
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
|
public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats,
|
||||||
String[][] out_strings);
|
String format, String[][] out_strings);
|
||||||
|
|
||||||
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
||||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||||
|
|||||||
@ -623,17 +623,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR
|
|||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGBoosterDumpModel
|
* Method: XGBoosterDumpModelEx
|
||||||
* Signature: (JLjava/lang/String;I)[Ljava/lang/String;
|
* Signature: (JLjava/lang/String;I)[Ljava/lang/String;
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelEx
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jstring jformat, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
|
const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
|
||||||
|
const char *format = jenv->GetStringUTFChars(jformat, 0);
|
||||||
bst_ulong len = 0;
|
bst_ulong len = 0;
|
||||||
char **result;
|
char **result;
|
||||||
|
|
||||||
int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result);
|
int ret = XGBoosterDumpModelEx(handle, fmap, jwith_stats, format, &len, (const char ***) &result);
|
||||||
|
|
||||||
jsize jlen = (jsize) len;
|
jsize jlen = (jsize) len;
|
||||||
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
|
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
|
||||||
|
|||||||
@ -217,11 +217,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR
|
|||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
* Method: XGBoosterDumpModel
|
* Method: XGBoosterDumpModelEx
|
||||||
* Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
|
* Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
|
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelEx
|
||||||
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
|
(JNIEnv *, jclass, jlong, jstring, jint, jstring, jobjectArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||||
|
|||||||
@ -231,6 +231,16 @@ public class BoosterImplTest {
|
|||||||
testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f);
|
testWithFastHisto(trainMat, watches, 10, paramMap, 0.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testDumpModelJson() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
Booster booster = trainBooster(trainMat, testMat);
|
||||||
|
String[] dump = booster.getModelDump("", false, "json");
|
||||||
|
TestCase.assertEquals(" { \"nodeid\":", dump[0].substring(0, 13));
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFastHistoDepthwiseMaxDepth() throws XGBoostError {
|
public void testFastHistoDepthwiseMaxDepth() throws XGBoostError {
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user