[jvm-packages] support specified feature names when getModelDump and getFeatureScore (#3733)

* [jvm-packages] support specified feature names for jvm when get ModelDump and get FeatureScore (#3725)

* typo and style fix
This commit is contained in:
zengxy
2018-10-05 00:05:42 +08:00
committed by Nan Zhu
parent 34522d56f0
commit 9e73087324
6 changed files with 161 additions and 3 deletions

View File

@@ -369,15 +369,68 @@ public class Booster implements Serializable, KryoSerializable {
return modelInfos[0];
}
/**
* Get the dump of the model as a string array with specified feature names.
*
* @param featureNames Names of the features.
* @return dumped model information
* @throws XGBoostError
*/
public String[] getModelDump(String[] featureNames, boolean withStats) throws XGBoostError {
return getModelDump(featureNames, withStats, "text");
}
public String[] getModelDump(String[] featureNames, boolean withStats, String format)
throws XGBoostError {
int statsFlag = 0;
if (withStats) {
statsFlag = 1;
}
if (format == null) {
format = "text";
}
String[][] modelInfos = new String[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelExWithFeatures(
handle, featureNames, statsFlag, format, modelInfos));
return modelInfos[0];
}
/**
* Get importance of each feature with specified feature names.
*
* @return featureScoreMap key: feature name, value: feature importance score, can be nill.
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore(String[] featureNames) throws XGBoostError {
String[] modelInfos = getModelDump(featureNames, false);
Map<String, Integer> featureScore = new HashMap<>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");
if (array.length == 1) {
continue;
}
String fid = array[1].split("\\]")[0];
fid = fid.split("<")[0];
if (featureScore.containsKey(fid)) {
featureScore.put(fid, 1 + featureScore.get(fid));
} else {
featureScore.put(fid, 1);
}
}
}
return featureScore;
}
/**
* Get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score, can be nill
* @return featureScoreMap key: feature index, value: feature importance score, can be nill
* @throws XGBoostError native error
*/
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
String[] modelInfos = getModelDump(featureMap, false);
Map<String, Integer> featureScore = new HashMap<String, Integer>();
Map<String, Integer> featureScore = new HashMap<>();
for (String tree : modelInfos) {
for (String node : tree.split("\n")) {
String[] array = node.split("\\[");

View File

@@ -111,6 +111,9 @@ class XGBoostJNI {
public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats,
String format, String[][] out_strings);
public final static native int XGBoosterDumpModelExWithFeatures(
long handle, String[] feature_names, int with_stats, String format, String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);

View File

@@ -187,16 +187,42 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
booster.getModelDump(featureMap, withStats, format)
}
/**
* Dump model as Array of string with specified feature names.
*
* @param featureNames Names of features.
*/
@throws(classOf[XGBoostError])
def getModelDump(featureNames: Array[String]): Array[String] = {
booster.getModelDump(featureNames, false, "text")
}
def getModelDump(featureNames: Array[String], withStats: Boolean, format: String)
: Array[String] = {
booster.getModelDump(featureNames, withStats, format)
}
/**
* Get importance of each feature
*
* @return featureMap key: feature index, value: feature importance score
* @return featureScoreMap key: feature index, value: feature importance score
*/
@throws(classOf[XGBoostError])
def getFeatureScore(featureMap: String = null): mutable.Map[String, Integer] = {
booster.getFeatureScore(featureMap).asScala
}
/**
* Get importance of each feature with specified feature names.
*
* @return featureScoreMap key: feature name, value: feature importance score
*/
@throws(classOf[XGBoostError])
def getFeatureScore(featureNames: Array[String]): mutable.Map[String, Integer] = {
booster.getFeatureScore(featureNames).asScala
}
def getVersion: Int = booster.getVersion
def toByteArray: Array[Byte] = {