[jvm-packages] support specified feature names when getModelDump and getFeatureScore (#3733)
* [jvm-packages] support specified feature names for jvm when get ModelDump and get FeatureScore (#3725) * typo and style fix
This commit is contained in:
parent
34522d56f0
commit
9e73087324
@ -369,15 +369,68 @@ public class Booster implements Serializable, KryoSerializable {
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the dump of the model as a string array with specified feature names.
|
||||
*
|
||||
* @param featureNames Names of the features.
|
||||
* @return dumped model information
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public String[] getModelDump(String[] featureNames, boolean withStats) throws XGBoostError {
|
||||
return getModelDump(featureNames, withStats, "text");
|
||||
}
|
||||
|
||||
public String[] getModelDump(String[] featureNames, boolean withStats, String format)
|
||||
throws XGBoostError {
|
||||
int statsFlag = 0;
|
||||
if (withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
if (format == null) {
|
||||
format = "text";
|
||||
}
|
||||
String[][] modelInfos = new String[1][];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelExWithFeatures(
|
||||
handle, featureNames, statsFlag, format, modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get importance of each feature with specified feature names.
|
||||
*
|
||||
* @return featureScoreMap key: feature name, value: feature importance score, can be nill.
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore(String[] featureNames) throws XGBoostError {
|
||||
String[] modelInfos = getModelDump(featureNames, false);
|
||||
Map<String, Integer> featureScore = new HashMap<>();
|
||||
for (String tree : modelInfos) {
|
||||
for (String node : tree.split("\n")) {
|
||||
String[] array = node.split("\\[");
|
||||
if (array.length == 1) {
|
||||
continue;
|
||||
}
|
||||
String fid = array[1].split("\\]")[0];
|
||||
fid = fid.split("<")[0];
|
||||
if (featureScore.containsKey(fid)) {
|
||||
featureScore.put(fid, 1 + featureScore.get(fid));
|
||||
} else {
|
||||
featureScore.put(fid, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
return featureScore;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get importance of each feature
|
||||
*
|
||||
* @return featureMap key: feature index, value: feature importance score, can be nill
|
||||
* @return featureScoreMap key: feature index, value: feature importance score, can be nill
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore(String featureMap) throws XGBoostError {
|
||||
String[] modelInfos = getModelDump(featureMap, false);
|
||||
Map<String, Integer> featureScore = new HashMap<String, Integer>();
|
||||
Map<String, Integer> featureScore = new HashMap<>();
|
||||
for (String tree : modelInfos) {
|
||||
for (String node : tree.split("\n")) {
|
||||
String[] array = node.split("\\[");
|
||||
|
||||
@ -111,6 +111,9 @@ class XGBoostJNI {
|
||||
public final static native int XGBoosterDumpModelEx(long handle, String fmap, int with_stats,
|
||||
String format, String[][] out_strings);
|
||||
|
||||
public final static native int XGBoosterDumpModelExWithFeatures(
|
||||
long handle, String[] feature_names, int with_stats, String format, String[][] out_strings);
|
||||
|
||||
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
|
||||
public final static native int XGBoosterSetAttr(long handle, String key, String value);
|
||||
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
|
||||
|
||||
@ -187,16 +187,42 @@ class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
|
||||
booster.getModelDump(featureMap, withStats, format)
|
||||
}
|
||||
|
||||
/**
|
||||
* Dump model as Array of string with specified feature names.
|
||||
*
|
||||
* @param featureNames Names of features.
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getModelDump(featureNames: Array[String]): Array[String] = {
|
||||
booster.getModelDump(featureNames, false, "text")
|
||||
}
|
||||
|
||||
def getModelDump(featureNames: Array[String], withStats: Boolean, format: String)
|
||||
: Array[String] = {
|
||||
booster.getModelDump(featureNames, withStats, format)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get importance of each feature
|
||||
*
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
* @return featureScoreMap key: feature index, value: feature importance score
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureScore(featureMap: String = null): mutable.Map[String, Integer] = {
|
||||
booster.getFeatureScore(featureMap).asScala
|
||||
}
|
||||
|
||||
/**
|
||||
* Get importance of each feature with specified feature names.
|
||||
*
|
||||
* @return featureScoreMap key: feature name, value: feature importance score
|
||||
*/
|
||||
@throws(classOf[XGBoostError])
|
||||
def getFeatureScore(featureNames: Array[String]): mutable.Map[String, Integer] = {
|
||||
booster.getFeatureScore(featureNames).asScala
|
||||
}
|
||||
|
||||
def getVersion: Int = booster.getVersion
|
||||
|
||||
def toByteArray: Array[Byte] = {
|
||||
|
||||
@ -656,6 +656,56 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterDumpModelExWithFeatures
|
||||
* Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelExWithFeatures
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jobjectArray jfeature_names, jint jwith_stats,
|
||||
jstring jformat, jobjectArray jout) {
|
||||
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeature_names);
|
||||
|
||||
std::vector<std::string> feature_names;
|
||||
std::vector<char*> feature_names_char;
|
||||
|
||||
std::string feature_type_q = "q";
|
||||
std::vector<char*> feature_types_char;
|
||||
|
||||
for (bst_ulong i = 0; i < feature_num; ++i) {
|
||||
jstring jfeature_name = (jstring)jenv->GetObjectArrayElement(jfeature_names, i);
|
||||
const char *s = jenv->GetStringUTFChars(jfeature_name, 0);
|
||||
feature_names.push_back(std::string(s, jenv->GetStringLength(jfeature_name)));
|
||||
if (s != nullptr) jenv->ReleaseStringUTFChars(jfeature_name, s);
|
||||
if (feature_names.back().length() == 0) feature_names.pop_back();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < feature_names.size(); ++i) {
|
||||
feature_names_char.push_back(&feature_names[i][0]);
|
||||
feature_types_char.push_back(&feature_type_q[0]);
|
||||
}
|
||||
|
||||
const char *format = jenv->GetStringUTFChars(jformat, 0);
|
||||
bst_ulong len = 0;
|
||||
char **result;
|
||||
|
||||
int ret = XGBoosterDumpModelExWithFeatures(handle, feature_num,
|
||||
(const char **) dmlc::BeginPtr(feature_names_char),
|
||||
(const char **) dmlc::BeginPtr(feature_types_char),
|
||||
jwith_stats, format, &len, (const char ***) &result);
|
||||
|
||||
jsize jlen = (jsize) len;
|
||||
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
|
||||
for(int i=0 ; i<jlen; i++) {
|
||||
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i]));
|
||||
}
|
||||
jenv->SetObjectArrayElement(jout, 0, jinfos);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterLoadRabitCheckpoint
|
||||
|
||||
@ -223,6 +223,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetModelR
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelEx
|
||||
(JNIEnv *, jclass, jlong, jstring, jint, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterDumpModelExWithFeatures
|
||||
* Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelExWithFeatures
|
||||
(JNIEnv *, jclass, jlong, jobjectArray, jint, jstring, jobjectArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGBoosterGetAttr
|
||||
|
||||
@ -271,6 +271,24 @@ public class BoosterImplTest {
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] dump = booster.getModelDump("", false, "json");
|
||||
TestCase.assertEquals(" { \"nodeid\":", dump[0].substring(0, 13));
|
||||
|
||||
// test with specified feature names
|
||||
String[] featureNames = new String[126];
|
||||
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
|
||||
dump = booster.getModelDump(featureNames, false, "json");
|
||||
TestCase.assertTrue(dump[0].contains("test_feature_name_"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGetFeatureImportance() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
Booster booster = trainBooster(trainMat, testMat);
|
||||
String[] featureNames = new String[126];
|
||||
for(int i = 0; i < 126; i++) featureNames[i] = "test_feature_name_" + i;
|
||||
Map<String, Integer> scoreMap = booster.getFeatureScore(featureNames);
|
||||
for (String fName: scoreMap.keySet()) TestCase.assertTrue(fName.startsWith("test_feature_name_"));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user