From 9e73087324f72d8695e54da96e5e133769e662ca Mon Sep 17 00:00:00 2001 From: zengxy Date: Fri, 5 Oct 2018 00:05:42 +0800 Subject: [PATCH] [jvm-packages] support specified feature names when getModelDump and getFeatureScore (#3733) * [jvm-packages] support specified feature names for jvm when get ModelDump and get FeatureScore (#3725) * typo and style fix --- .../java/ml/dmlc/xgboost4j/java/Booster.java | 57 ++++++++++++++++++- .../ml/dmlc/xgboost4j/java/XGBoostJNI.java | 3 + .../ml/dmlc/xgboost4j/scala/Booster.scala | 28 ++++++++- .../xgboost4j/src/native/xgboost4j.cpp | 50 ++++++++++++++++ jvm-packages/xgboost4j/src/native/xgboost4j.h | 8 +++ .../dmlc/xgboost4j/java/BoosterImplTest.java | 18 ++++++ 6 files changed, 161 insertions(+), 3 deletions(-) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 0faa6bb58..851a1133e 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -369,15 +369,68 @@ public class Booster implements Serializable, KryoSerializable { return modelInfos[0]; } + /** + * Get the dump of the model as a string array with specified feature names. + * + * @param featureNames Names of the features. + * @return dumped model information + * @throws XGBoostError + */ + public String[] getModelDump(String[] featureNames, boolean withStats) throws XGBoostError { + return getModelDump(featureNames, withStats, "text"); + } + + public String[] getModelDump(String[] featureNames, boolean withStats, String format) + throws XGBoostError { + int statsFlag = 0; + if (withStats) { + statsFlag = 1; + } + if (format == null) { + format = "text"; + } + String[][] modelInfos = new String[1][]; + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelExWithFeatures( + handle, featureNames, statsFlag, format, modelInfos)); + return modelInfos[0]; + } + + /** + * Get importance of each feature with specified feature names. + * + * @return featureScoreMap key: feature name, value: feature importance score, can be nill. + * @throws XGBoostError native error + */ + public Map getFeatureScore(String[] featureNames) throws XGBoostError { + String[] modelInfos = getModelDump(featureNames, false); + Map featureScore = new HashMap<>(); + for (String tree : modelInfos) { + for (String node : tree.split("\n")) { + String[] array = node.split("\\["); + if (array.length == 1) { + continue; + } + String fid = array[1].split("\\]")[0]; + fid = fid.split("<")[0]; + if (featureScore.containsKey(fid)) { + featureScore.put(fid, 1 + featureScore.get(fid)); + } else { + featureScore.put(fid, 1); + } + } + } + return featureScore; + } + /** * Get importance of each feature * - * @return featureMap key: feature index, value: feature importance score, can be nill + * @return featureScoreMap key: feature index, value: feature importance score, can be nill * @throws XGBoostError native error */ public Map getFeatureScore(String featureMap) throws XGBoostError { String[] modelInfos = getModelDump(featureMap, false); - Map featureScore = new HashMap(); + Map featureScore = new HashMap<>(); for (String tree : modelInfos) { for (String node : tree.split("\n")) { String[] array = node.split("\\["); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 6922f26b0..d20a27ba1 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -111,6 +111,9 @@ class XGBoostJNI { public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats, String format, String[][] out_strings); + public final static native int XGBoosterDumpModelExWithFeatures( + long handle, String[] feature_names, int with_stats, String format, String[][] out_strings); + public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string); public final static native int XGBoosterSetAttr(long handle, String key, String value); public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version); diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala index c9013d3c7..60b98f867 100644 --- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -187,16 +187,42 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster) booster.getModelDump(featureMap, withStats, format) } + /** + * Dump model as Array of string with specified feature names. + * + * @param featureNames Names of features. + */ + @throws(classOf[XGBoostError]) + def getModelDump(featureNames: Array[String]): Array[String] = { + booster.getModelDump(featureNames, false, "text") + } + + def getModelDump(featureNames: Array[String], withStats: Boolean, format: String) + : Array[String] = { + booster.getModelDump(featureNames, withStats, format) + } + + /** * Get importance of each feature * - * @return featureMap key: feature index, value: feature importance score + * @return featureScoreMap key: feature index, value: feature importance score */ @throws(classOf[XGBoostError]) def getFeatureScore(featureMap: String = null): mutable.Map[String, Integer] = { booster.getFeatureScore(featureMap).asScala } + /** + * Get importance of each feature with specified feature names. + * + * @return featureScoreMap key: feature name, value: feature importance score + */ + @throws(classOf[XGBoostError]) + def getFeatureScore(featureNames: Array[String]): mutable.Map[String, Integer] = { + booster.getFeatureScore(featureNames).asScala + } + def getVersion: Int = booster.getVersion def toByteArray: Array[Byte] = { diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index de0383143..02ab0b228 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -656,6 +656,56 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel return ret; } +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterDumpModelExWithFeatures + * Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelExWithFeatures + (JNIEnv *jenv, jclass jcls, jlong jhandle, jobjectArray jfeature_names, jint jwith_stats, + jstring jformat, jobjectArray jout) { + + BoosterHandle handle = (BoosterHandle) jhandle; + bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeature_names); + + std::vector feature_names; + std::vector feature_names_char; + + std::string feature_type_q = "q"; + std::vector feature_types_char; + + for (bst_ulong i = 0; i < feature_num; ++i) { + jstring jfeature_name = (jstring)jenv->GetObjectArrayElement(jfeature_names, i); + const char *s = jenv->GetStringUTFChars(jfeature_name, 0); + feature_names.push_back(std::string(s, jenv->GetStringLength(jfeature_name))); + if (s != nullptr) jenv->ReleaseStringUTFChars(jfeature_name, s); + if (feature_names.back().length() == 0) feature_names.pop_back(); + } + + for (size_t i = 0; i < feature_names.size(); ++i) { + feature_names_char.push_back(&feature_names[i][0]); + feature_types_char.push_back(&feature_type_q[0]); + } + + const char *format = jenv->GetStringUTFChars(jformat, 0); + bst_ulong len = 0; + char **result; + + int ret = XGBoosterDumpModelExWithFeatures(handle, feature_num, + (const char **) dmlc::BeginPtr(feature_names_char), + (const char **) dmlc::BeginPtr(feature_types_char), + jwith_stats, format, &len, (const char ***) &result); + + jsize jlen = (jsize) len; + jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); + for(int i=0 ; iSetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i])); + } + jenv->SetObjectArrayElement(jout, 0, jinfos); + + return ret; +} + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterLoadRabitCheckpoint diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index 231f727ce..3994d7825 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -223,6 +223,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelEx (JNIEnv *, jclass, jlong, jstring, jint, jstring, jobjectArray); +/* + * Class: ml_dmlc_xgboost4j_java_XGBoostJNI + * Method: XGBoosterDumpModelExWithFeatures + * Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelExWithFeatures + (JNIEnv *, jclass, jlong, jobjectArray, jint, jstring, jobjectArray); + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterGetAttr diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java index 1a7ad9d68..85d2b61d2 100644 --- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java @@ -271,6 +271,24 @@ public class BoosterImplTest { Booster booster = trainBooster(trainMat, testMat); String[] dump = booster.getModelDump("", false, "json"); TestCase.assertEquals(" { \"nodeid\":", dump[0].substring(0, 13)); + + // test with specified feature names + String[] featureNames = new String[126]; + for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; + dump = booster.getModelDump(featureNames, false, "json"); + TestCase.assertTrue(dump[0].contains("test_feature_name_")); + } + + @Test + public void testGetFeatureImportance() throws XGBoostError { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + Booster booster = trainBooster(trainMat, testMat); + String[] featureNames = new String[126]; + for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i; + Map scoreMap = booster.getFeatureScore(featureNames); + for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_")); } @Test